ichou1のブログ

主に音声認識、時々、データ分析のことを書く

Kerasメモ(XLNet)その4

前回の続き。

PositionalEmbeddingレイヤを見てみる。

keras_xlnet/xlnet.py
def build_xlnet(...):
    ...
    pos_embed = PositionalEmbedding(
        output_dim=units,
        clamp_len=clamp_len,
        directional=attention_type == 'uni',
        name='Embed-Pos',
    )([token_embed, memories[0]])
keras_xlnet/position_embed.py
class PositionalEmbedding(keras.layers.Layer):
    ...
    def call(self, inputs, **kwargs):
        q_len, m_len = K.shape(inputs[0])[1], K.shape(inputs[1])[1]
        # q_len: 3
        # m_len: prev_len

仮に、「m_len」の値が「6」だったとする。

        k_len = q_len + m_len
        # --> k_len = 9
        start, stop = k_len, -1

        inputs = K.tile(
            K.expand_dims(K.arange(start, stop, -1, dtype=K.floatx()), axis=0),
            [K.shape(inputs[0])[0], 1],
        )

inputsの形状は(batch_size, 10)となる。
ここで「10」(= k_len + 1)となっていることに注意する。

inputs
[[9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]
  ...
 [9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]]

後続の計算。

        inputs = K.expand_dims(inputs, axis=-1)
        # --> (batch_size, 10, 1)

        output_dim = K.cast(self.output_dim, K.floatx())
        # --> 768.0

        ranges = K.expand_dims(K.arange(0.0, self.output_dim, 2.0), axis=0) / output_dim
        # --> (1, 384)
        # value: 0.0, ..., 0.99739583

        inverse = 1.0 / K.pow(10000.0, ranges)
        # value: 1.0, ..., 0.000102427522

        #broadcast        
        positions = inputs * inverse
        # --> (batch_size, 10, 384)

positionsの中身。

[[9.        , 8.78670089, 8.57845695, ..., 0.00096715, 0.00094423, 0.00092185]
 [8.        , 7.81040079, 7.62529507, ..., 0.00085969, 0.00083931, 0.00081942]
 [7.        , 6.83410069, 6.67213318, ..., 0.00075223, 0.0007344 , 0.00071699]
 ...
 [1.        , 0.9763001 , 0.95316188, ..., 0.00010746, 0.00010491, 0.00010243]
 [0.        , 0.         ,0.        , ..., 0.        , 0.        , 0.        ]]

positionsのsinとcosを計算。

        positions_sin = K.sin(positions)
        positions_cos = K.cos(positions)

positions_sinの中身。

[[ 0.41211849,  0.59565196,  0.74884726,  0.86723886, ..., 0.00094423, 0.00092185]
 [ 0.98935825,  0.99905051,  0.97396499,  0.91735771, ..., 0.00083931, 0.00081942] 
 [ 0.6569866 ,  0.5234674 ,  0.37921509,  0.22877486, ..., 0.0007344 , 0.00071699] 
 [-0.2794155 , -0.41267124, -0.53475181, -0.6440288,  ..., 0.00062948, 0.00061457]
 [-0.95892427, -0.98573469, -0.99857347, -0.99822869, ..., 0.00052457, 0.00051214]
 [-0.7568025 , -0.69153198, -0.62181248, -0.54860557, ..., 0.00041966, 0.00040971]
 [ 0.14112001,  0.21109235,  0.27837998,  0.34278182, ..., 0.00031474, 0.00030728]
 [ 0.90929743,  0.92799403,  0.94423677,  0.95814438, ..., 0.00020983, 0.00020486]
 [ 0.84147098,  0.82843076,  0.81525065,  0.8019618 , ..., 0.00010491, 0.00010243]
 [ 0.        ,  0.        ,  0.        ,  0.        , ..., 0.        , 0.        ]]

positions_cosの中身。

[[-0.91113026, -0.80324264, -0.66274263, -0.49789231, ..., 0.99999955, 0.99999958] 
 [-0.14550003,  0.04356705,  0.22669848,  0.39806385, ..., 0.99999965, 0.99999966]
 [ 0.75390225,  0.8520457 ,  0.92530855,  0.97347936, ..., 0.99999973, 0.99999974]
 [ 0.96017029,  0.91088004,  0.84500917,  0.76500125, ..., 0.9999998 , 0.99999981]
 [ 0.28366219,  0.1683066 ,  0.05339503, -0.05949362, ..., 0.99999986, 0.99999987]
 [-0.65364362, -0.72234585, -0.78316616, -0.83608129, ..., 0.99999991, 0.99999992] 
 [-0.9899925 , -0.97746612, -0.96047102, -0.93941504, ..., 0.99999995, 0.99999995]
 [-0.41614684, -0.37259506, -0.32926724, -0.28628544, ..., 0.99999998, 0.99999998]
 [ 0.54030231,  0.56009149,  0.57910826,  0.59737533, ..., 0.99999999, 0.99999999]
 [ 1.        ,  1.        ,  1.        ,  1.        , ..., 1.        , 1.        ]]

2つを結合。

        position_concat = K.concatenate([positions_sin, positions_cos], axis=-1)
        return position_concat
        # --> (batch_size, 10, 786)
position_concat[0], position_concat[1]

f:id:ichou1:20191122215447p:plainf:id:ichou1:20191122215457p:plain

position_concat[2], position_concat[3]

f:id:ichou1:20191122215507p:plainf:id:ichou1:20191122215630p:plain

position_concat[4], position_concat[5]

f:id:ichou1:20191122215638p:plainf:id:ichou1:20191122215651p:plain

position_concat[6], position_concat[7]

f:id:ichou1:20191122215817p:plainf:id:ichou1:20191122215824p:plain

position_concat[8], position_concat[9]

f:id:ichou1:20191122215831p:plainf:id:ichou1:20191122215902p:plain

これが、attentionレイヤの順伝播におけるinputsの一つである"relatives"に該当する。

keras_xlnet/attention.py
class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):
    ...
    def call(self, inputs, mask=None, training=None):
        (inputs,
         content,
         memories,
         segment_mat,
         segment_embed,
         relatives,  # <--
         bias_context,
         bias_relative,
         bias_segment,
         permutation) = inputs

relativesに注目して計算過程を追ってみる。

(前回の再掲)PositionalEmbeddingと重みの内積を計算
# relatives:  Embed-Pos[0][0]  (batch, prev_len + seq_len + 1, units)
# kernel_r: (units, units)
w_r = K.dot(relatives, kernel_r)
# --> (batch, prev_len + seq_len + 1, units)
(前回の再掲)QeuryとPositionalEmbeddingのdot product
# bias_relative: Relative-Bias-*[0][1]
w_qr = K.bias_add(w_q, bias_relative)

# headごとにtensorを分割
w_qr = self._reshape_to_batches(w_qr)  # (batch * n_head, seq_len, units_head)
w_r = self._reshape_to_batches(w_r)    # (batch * n_head, prev_len + seq_len + 1, units_head)

# QeuryとPositionalEmbeddingのdot product
a_relative = K.batch_dot(w_qr, w_r, axes=2)
# --> a_relative: (batch * n_head, seq_len, prev_len + seq_len + 1)
# a_relative: (batch * n_head, seq_len, prev_len + seq_len + 1)
# key_len_expected: prev_len + seq_len
a_relative = self._relative_shift(               
    a_relative,
    key_len_expected=K.shape(a_context)[-1],
)
# --> a_relative: (batch * n_head, seq_len, prev_len + seq_len)

_relative_shift関数でやっていること。
(次元は0始まりとする)

  1. reshape(1次元と2次元を入れ替えた形状に変更)
  2. スライシング(1次元目のindex=0を捨てる)
  3. reshape(1次元(旧2次元)と2次元(旧1次元)を入れ替えた形状に変更)
class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):
    ...
    def _relative_shift(x, key_len_expected=-1):

        batch_size, q_len, k_len = K.shape(x)[0], K.shape(x)[1], K.shape(x)[2]
        # q_len : 3
        # k_len : 10 (= q_len + prev_len + 1)
        # --> x: (batch * n_head, 3, 10)

        x = K.reshape(x, (K.shape(x)[0], K.shape(x)[2], K.shape(x)[1]))
        # --> (batch * n_head, 10, 3)

        x = x[:, 1:, :]
        # --> # (batch * n_head, 9, 3)

        x = K.reshape(x, (batch_size, q_len, k_len - 1))
        # --> (batch * n_head, 3, 9)

        x = tf.slice(x, (0, 0, 0), (-1, -1, key_len_expected))
        # --> (batch * n_head, 3, 9)

        return x

処理の前後でどう変わるのか、テストコードで確認してみる。

元データ
# shape : (3, 10)
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]])
reshape
# shape : (10, 3)
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [21, 22, 23],
       [24, 25, 26],
       [27, 28, 29]])
slice([1:,:])
# shape : (9, 3)
array([[ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [21, 22, 23],
       [24, 25, 26],
       [27, 28, 29]])
reshape
# shape : (3, 9)
array([[ 3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25, 26, 27, 28, 29]])

最初の3つのデータ"0"、"1"、"2"が捨てられている。

次回に続く。