Keras notes (XLNet), part 4
Continuing from the previous post.
This time, let's take a look at the PositionalEmbedding layer.
keras_xlnet/xlnet.py
def build_xlnet(...):
    ...
    pos_embed = PositionalEmbedding(
        output_dim=units,
        clamp_len=clamp_len,
        directional=attention_type == 'uni',
        name='Embed-Pos',
    )([token_embed, memories[0]])
keras_xlnet/position_embed.py
class PositionalEmbedding(keras.layers.Layer):
    ...
    def call(self, inputs, **kwargs):
        q_len, m_len = K.shape(inputs[0])[1], K.shape(inputs[1])[1]
        # q_len: 3
        # m_len: prev_len
Suppose the value of m_len is 6.
k_len = q_len + m_len  # --> k_len = 9
start, stop = k_len, -1
inputs = K.tile(
    K.expand_dims(K.arange(start, stop, -1, dtype=K.floatx()), axis=0),
    [K.shape(inputs[0])[0], 1],
)
The shape of inputs becomes (batch_size, 10).
Note that the length here is 10 (= k_len + 1).
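As a quick check (a standalone NumPy sketch, not library code), counting down from k_len to 0 inclusive does give k_len + 1 values:

import numpy as np

k_len = 9
# K.arange(start, stop, -1) with start=k_len and stop=-1 counts down to 0 inclusive,
# so the resulting vector has k_len + 1 = 10 elements.
rel_positions = np.arange(k_len, -1, -1, dtype=np.float32)
print(rel_positions)        # [9. 8. 7. 6. 5. 4. 3. 2. 1. 0.]
print(rel_positions.shape)  # (10,)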
inputs
[[9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]
 ...
 [9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]]
The computation continues as follows.
inputs = K.expand_dims(inputs, axis=-1)  # --> (batch_size, 10, 1)
output_dim = K.cast(self.output_dim, K.floatx())  # --> 768.0
ranges = K.expand_dims(K.arange(0.0, self.output_dim, 2.0), axis=0) / output_dim
# --> (1, 384)
# value: 0.0, ..., 0.99739583
inverse = 1.0 / K.pow(10000.0, ranges)
# value: 1.0, ..., 0.000102427522
# broadcast
positions = inputs * inverse  # --> (batch_size, 10, 384)
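The endpoint values noted in the comments above can be verified with a quick standalone check (my own NumPy sketch, assuming output_dim = 768):

import numpy as np

output_dim = 768
ranges = np.arange(0.0, output_dim, 2.0) / output_dim  # 384 values: 0, 2/768, ..., 766/768
inverse = 1.0 / np.power(10000.0, ranges)              # inverse frequencies 1 / 10000^(2i/d)
print(ranges[-1])   # ≈ 0.99739583
print(inverse[-1])  # ≈ 0.000102427522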
Contents of positions:
[[9.        , 8.78670089, 8.57845695, ..., 0.00096715, 0.00094423, 0.00092185]
 [8.        , 7.81040079, 7.62529507, ..., 0.00085969, 0.00083931, 0.00081942]
 [7.        , 6.83410069, 6.67213318, ..., 0.00075223, 0.0007344 , 0.00071699]
 ...
 [1.        , 0.9763001 , 0.95316188, ..., 0.00010746, 0.00010491, 0.00010243]
 [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]]
Compute the sin and cos of positions.
positions_sin = K.sin(positions)
positions_cos = K.cos(positions)
Contents of positions_sin:
[[ 0.41211849,  0.59565196,  0.74884726,  0.86723886, ...,  0.00094423,  0.00092185]
 [ 0.98935825,  0.99905051,  0.97396499,  0.91735771, ...,  0.00083931,  0.00081942]
 [ 0.6569866 ,  0.5234674 ,  0.37921509,  0.22877486, ...,  0.0007344 ,  0.00071699]
 [-0.2794155 , -0.41267124, -0.53475181, -0.6440288 , ...,  0.00062948,  0.00061457]
 [-0.95892427, -0.98573469, -0.99857347, -0.99822869, ...,  0.00052457,  0.00051214]
 [-0.7568025 , -0.69153198, -0.62181248, -0.54860557, ...,  0.00041966,  0.00040971]
 [ 0.14112001,  0.21109235,  0.27837998,  0.34278182, ...,  0.00031474,  0.00030728]
 [ 0.90929743,  0.92799403,  0.94423677,  0.95814438, ...,  0.00020983,  0.00020486]
 [ 0.84147098,  0.82843076,  0.81525065,  0.8019618 , ...,  0.00010491,  0.00010243]
 [ 0.        ,  0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ]]
Contents of positions_cos:
[[-0.91113026, -0.80324264, -0.66274263, -0.49789231, ...,  0.99999955,  0.99999958]
 [-0.14550003,  0.04356705,  0.22669848,  0.39806385, ...,  0.99999965,  0.99999966]
 [ 0.75390225,  0.8520457 ,  0.92530855,  0.97347936, ...,  0.99999973,  0.99999974]
 [ 0.96017029,  0.91088004,  0.84500917,  0.76500125, ...,  0.9999998 ,  0.99999981]
 [ 0.28366219,  0.1683066 ,  0.05339503, -0.05949362, ...,  0.99999986,  0.99999987]
 [-0.65364362, -0.72234585, -0.78316616, -0.83608129, ...,  0.99999991,  0.99999992]
 [-0.9899925 , -0.97746612, -0.96047102, -0.93941504, ...,  0.99999995,  0.99999995]
 [-0.41614684, -0.37259506, -0.32926724, -0.28628544, ...,  0.99999998,  0.99999998]
 [ 0.54030231,  0.56009149,  0.57910826,  0.59737533, ...,  0.99999999,  0.99999999]
 [ 1.        ,  1.        ,  1.        ,  1.        , ...,  1.        ,  1.        ]]
Concatenate the two.
position_concat = K.concatenate([positions_sin, positions_cos], axis=-1)
return position_concat  # --> (batch_size, 10, 768)
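For reference, the whole forward pass of this layer can be reproduced with plain NumPy. The following is a minimal sketch (my own, assuming batch_size = 1, q_len = 3, m_len = 6, output_dim = 768) that yields the arrays shown above:

import numpy as np

batch_size, q_len, m_len, output_dim = 1, 3, 6, 768
k_len = q_len + m_len  # 9

# Relative positions counted down from k_len to 0 (k_len + 1 = 10 values).
pos = np.tile(np.arange(k_len, -1, -1, dtype=np.float32)[None, :], (batch_size, 1))
pos = pos[..., None]                                            # (batch_size, 10, 1)

ranges = np.arange(0.0, output_dim, 2.0)[None, :] / output_dim  # (1, 384)
inverse = 1.0 / np.power(10000.0, ranges)                       # (1, 384)

positions = pos * inverse                                       # (batch_size, 10, 384)
position_concat = np.concatenate([np.sin(positions), np.cos(positions)], axis=-1)
print(position_concat.shape)                                    # (1, 10, 768)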
(Figures: plots of position_concat[0] through position_concat[9].)
This corresponds to "relatives", one of the inputs to the attention layer's forward pass.
keras_xlnet/attention.py
class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):
    ...
    def call(self, inputs, mask=None, training=None):
        (inputs, content, memories, segment_mat, segment_embed,
         relatives,  # <--
         bias_context, bias_relative, bias_segment, permutation) = inputs
Let's follow the computation, focusing on relatives.
(Repeated from the previous post) Compute the dot product of the PositionalEmbedding and the weights.
# relatives: Embed-Pos[0][0] (batch, prev_len + seq_len + 1, units)
# kernel_r: (units, units)
w_r = K.dot(relatives, kernel_r)  # --> (batch, prev_len + seq_len + 1, units)
(Repeated from the previous post) Dot product of the query and the PositionalEmbedding.
# bias_relative: Relative-Bias-*[0][1]
w_qr = K.bias_add(w_q, bias_relative)
# Split the tensor per head
w_qr = self._reshape_to_batches(w_qr)  # (batch * n_head, seq_len, units_head)
w_r = self._reshape_to_batches(w_r)    # (batch * n_head, prev_len + seq_len + 1, units_head)
# Dot product of the query and the PositionalEmbedding
a_relative = K.batch_dot(w_qr, w_r, axes=2)
# --> a_relative: (batch * n_head, seq_len, prev_len + seq_len + 1)
# a_relative: (batch * n_head, seq_len, prev_len + seq_len + 1)
# key_len_expected: prev_len + seq_len
a_relative = self._relative_shift(
    a_relative,
    key_len_expected=K.shape(a_context)[-1],
)
# --> a_relative: (batch * n_head, seq_len, prev_len + seq_len)
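To make the shape flow concrete, here is a small standalone sketch (my own, with dummy sizes batch * n_head = 2, seq_len = 3, prev_len = 6, units_head = 64); np.einsum plays the role of K.batch_dot(..., axes=2):

import numpy as np

bh, seq_len, prev_len, units_head = 2, 3, 6, 64

w_qr = np.random.rand(bh, seq_len, units_head)                 # query side
w_r = np.random.rand(bh, prev_len + seq_len + 1, units_head)   # positional side

# Contract the last (feature) axis, as K.batch_dot(w_qr, w_r, axes=2) does.
a_relative = np.einsum('bqd,bkd->bqk', w_qr, w_r)
print(a_relative.shape)  # (2, 3, 10) = (batch * n_head, seq_len, prev_len + seq_len + 1)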
What the _relative_shift function does:
(Dimensions are counted from 0.)
- reshape (to a shape with dimensions 1 and 2 swapped)
- slicing (drop index 0 along dimension 1)
- reshape (to a shape with dimension 1 (the old dimension 2) and dimension 2 (the old dimension 1) swapped back)
class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):
    ...
    def _relative_shift(x, key_len_expected=-1):
        batch_size, q_len, k_len = K.shape(x)[0], K.shape(x)[1], K.shape(x)[2]
        # q_len: 3
        # k_len: 10 (= q_len + prev_len + 1)
        # --> x: (batch * n_head, 3, 10)
        x = K.reshape(x, (K.shape(x)[0], K.shape(x)[2], K.shape(x)[1]))  # --> (batch * n_head, 10, 3)
        x = x[:, 1:, :]  # --> (batch * n_head, 9, 3)
        x = K.reshape(x, (batch_size, q_len, k_len - 1))  # --> (batch * n_head, 3, 9)
        x = tf.slice(x, (0, 0, 0), (-1, -1, key_len_expected))  # --> (batch * n_head, 3, 9)
        return x
Let's check with test code how the data changes before and after this processing.
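A minimal NumPy sketch of such test code (my own, using a single (3, 10) array with no batch dimension) might look like this:

import numpy as np

x = np.arange(30).reshape(3, 10)  # original data: (3, 10)
x = x.reshape(10, 3)              # reshape: (10, 3)
x = x[1:, :]                      # slice: (9, 3), dropping the first row [0, 1, 2]
x = x.reshape(3, 9)               # reshape: (3, 9)
print(x)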
Original data
# shape: (3, 10)
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]])
reshape
# shape: (10, 3)
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [21, 22, 23],
       [24, 25, 26],
       [27, 28, 29]])
slice([1:,:])
# shape: (9, 3)
array([[ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [21, 22, 23],
       [24, 25, 26],
       [27, 28, 29]])
reshape
# shape: (3, 9)
array([[ 3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25, 26, 27, 28, 29]])
The first three values, 0, 1, and 2, have been discarded.
To be continued in the next post.