Keras notes (XLNet), part 3

Continuing from the previous post.

Let's see how the Attention layer differs from the one in BERT.
work-in-progress.hatenablog.com

An excerpt of model.summary from the post before last.

model.summary
Layer (type)            Output Shape          Param #  Connected to                
======================= ===================== ======== ============================
Attention-1             (None, 3, 768)        2949120  Embed-Token-Masked[0][0]    
(RelativePartialMulti-                                 Embed-Token-Masked[0][0]    
 HeadSelfAttention)                                    Memory-0[0][0]              
                                                       Embed-Segment-1[0][0]       
                                                       Embed-Segment-1[0][1]       
                                                       Embed-Pos[0][0]             
                                                       Relative-Bias-1[0][0]       
                                                       Relative-Bias-1[0][1]       
                                                       Segment-Bias-1[0][0]        
                                                       Permutation[0][0]           
keras_xlnet/attention.py

By default, "use_bias" is False.
The parameter count is 2,949,120 (= 768 * 768 * 5).
In addition to the Query, Key, Value, and Output weights, the layer holds a weight matrix for Relative (used with the output of the PositionalEmbedding layer).

class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):

    def __init__(self,
                 ...
                 use_bias=False,
                 ...

    def build(self, input_shape):
        # self.units: 768
        self.kernel = self.add_weight(
            shape=(self.units, self.units * 5),
            ... 
        )

    def call(self, inputs, mask=None, training=None):
        ...
        kernel_q  = self.kernel[:,                : self.units]
        kernel_kv = self.kernel[:, self.units     : self.units * 3]
        kernel_r  = self.kernel[:, self.units * 3 : self.units * 4]
        kernel_o  = self.kernel[:, self.units * 4 : self.units * 5]
        ...
(For reference) BERT

By default, "use_bias" is True.
The parameter count is 2,362,368 (= (768 * 768 * 4) + (768 * 4)).

Layer (type)            Output Shape          Param #  Connected to                
======================= ===================== ======== ============================
Encoder-1-              (None, 3, 768)        2362368  Embedding-Norm[0][0]
MultiHeadSelfAttention
(MultiHeadAttention)
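
As a quick sanity check of the two parameter counts, the arithmetic in plain Python:

units = 768

# XLNet: a single (units, units * 5) kernel covering Query / Key / Value / Relative / Output, no bias
print(units * units * 5)              # 2949120

# BERT: four (units, units) kernels (Query / Key / Value / Output) plus four (units,) biases
print(units * units * 4 + units * 4)  # 2362368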

Based on the source code, here is a diagram of the attention layer's internal processing at inference time.
(Splitting and merging of tensors per head is omitted.)
(Figure: the attention layer's internal processing at inference time)

Compute the dot product of the Query and the weight matrix
# inputs: Embed-Token-Masked[0][0]  (batch_size, seq_len, units)
# kernel_q : (units, units)
w_q = K.dot(inputs, kernel_q)
# --> (batch_size, seq_len, units)
Compute the dot product of the Key/Value and the weight matrix
# memories: Memory-0[0][0]  (batch_size, prev_len, units)
# content: Embed-Token-Masked[0][0]  (batch_size, seq_len, units)
full = K.concatenate([memories, content], axis=1)
# --> (batch_size, prev_len + seq_len, units)

# kernel_kv: (units, units * 2)
w_kv = K.dot(full, kernel_kv)
# --> (batch_size, prev_len + seq_len, units * 2)
Compute the dot product of the PositionalEmbedding and the weight matrix
# relatives:  Embed-Pos[0][0]  (batch_size, prev_len + seq_len + 1, units)
# kernel_r: (units, units)
w_r = K.dot(relatives, kernel_r)
# --> (batch_size, prev_len + seq_len + 1, units)


Dot product of the Query and the Key
w_k = w_kv[:, :, :self.units]  
# --> (batch_size, prev_len + seq_len, units)

# bias_context: Relative-Bias-*[0][0]
w_qc = K.bias_add(w_q, bias_context)

# Split the tensor per head
w_qc = self._reshape_to_batches(w_qc)
# --> (batch_size * n_head, seq_len, units_head)

w_k = self._reshape_to_batches(w_k)
# --> (batch_size * n_head, prev_len + seq_len, units_head)

# Dot product of the Query and the Key
a_context = K.batch_dot(w_qc, w_k, axes=2)
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
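
The helper "_reshape_to_batches" appears here without being shown; as far as I can tell it splits the feature axis into heads and folds the head axis into the batch axis. A minimal NumPy sketch of that assumed behaviour (the real helper works on Keras tensors, so details may differ):

import numpy as np

def reshape_to_batches(x, n_head):
    # x: (batch_size, length, units) --> (batch_size * n_head, length, units // n_head)
    batch_size, length, units = x.shape
    units_head = units // n_head
    x = x.reshape(batch_size, length, n_head, units_head)   # split the feature axis into heads
    x = x.transpose(0, 2, 1, 3)                              # move the head axis next to the batch axis
    return x.reshape(batch_size * n_head, length, units_head)

w_qc = np.random.rand(2, 3, 768)           # dummy (batch_size, seq_len, units)
print(reshape_to_batches(w_qc, 12).shape)  # (24, 3, 64)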
Dot product of the Query and the PositionalEmbedding
# bias_relative: Relative-Bias-*[0][1]
w_qr = K.bias_add(w_q, bias_relative)

# Split the tensor per head
w_qr = self._reshape_to_batches(w_qr)
# --> (batch_size * n_head, seq_len, units_head)
w_r = self._reshape_to_batches(w_r)
# --> (batch_size * n_head, prev_len + seq_len + 1, units_head)

# Dot product of the Query and the PositionalEmbedding
a_relative = K.batch_dot(w_qr, w_r, axes=2)
# --> (batch_size * n_head, seq_len, prev_len + seq_len + 1)

# Relative shift
a_relative = self._relative_shift(               
    a_relative,
    key_len_expected=K.shape(a_context)[-1],
)
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
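
The relative shift realigns the positional scores so that each query row ends up matched with the right relative offsets. A NumPy sketch of the reshape-based shift trick from Transformer-XL/XLNet, mirroring only the shapes above (the actual "_relative_shift" may differ in detail):

import numpy as np

def relative_shift(x, key_len):
    # x: (batch * n_head, q_len, k_len + 1) --> (batch * n_head, q_len, key_len)
    batch, q_len, width = x.shape
    x = x.reshape(batch, width, q_len)      # reinterpret each (q_len, width) block as (width, q_len)
    x = x[:, 1:, :]                         # drop the first row
    x = x.reshape(batch, q_len, width - 1)  # back to a query-major layout, now shifted row by row
    return x[:, :, :key_len]                # keep only key_len columns

a_rel = np.random.rand(24, 3, 8)               # dummy: seq_len=3, prev_len + seq_len + 1 = 8
print(relative_shift(a_rel, key_len=7).shape)  # (24, 3, 7)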
Dot product of the Query and the RelativeSegmentEmbeddings

I haven't yet worked out the "RelativeSegmentEmbeddings" layer, so I'll skip it here.
Its output is "a_segment".



Add up "a_context", "a_relative", and "a_segment" obtained from the three dot products, then divide by the square root of the per-head dimension.

att = (a_context + a_relative + a_segment) / K.sqrt(K.constant(self.units_head, dtype=K.floatx()))
# --> (batch_size * n_head, seq_len, prev_len + seq_len)

Compare the output shapes at this point.

BERT

(batch_size * n_head, seq_len, seq_len)

XLNet

(batch_size * n_head, seq_len, prev_len + seq_len)

Next comes the exponentiation step of the softmax (the row-wise max is subtracted first for numerical stability).

exp = K.exp(att - K.max(att, axis=-1, keepdims=True))
# --> (batch_size * n_head, seq_len, prev_len + seq_len)

A computation using the output of the "PermutationMask" layer (the content mask).

# permutation: Permutation[0][0] (Content mask)
# q_len: seq_len
# k_len: prev_len + seq_len 
permutation = K.tile(K.expand_dims(permutation, axis=1), [1, self.num_head, 1, 1])
permutation = K.reshape(permutation, (-1, q_len, k_len))
exp *= permutation
# --> (batch_size * n_head, seq_len, prev_len + seq_len)

After applying the mask, normalize to complete the softmax.

att = exp / (K.sum(exp, axis=-1, keepdims=True) + K.epsilon())
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
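
Incidentally, multiplying the exponentials by the 0/1 mask and then renormalizing is the same as a softmax taken only over the unmasked key positions. A quick NumPy check with made-up numbers:

import numpy as np

att = np.array([1.0, 2.0, 3.0])     # scores for one query over three key positions
mask = np.array([1.0, 1.0, 0.0])    # 1 = may attend, 0 = blocked

exp = np.exp(att - att.max())
masked = exp * mask / (exp * mask).sum()   # mask, then normalize (as above)

visible = att[:2]
plain = np.exp(visible - visible.max()) / np.exp(visible - visible.max()).sum()

print(masked)  # [0.26894142 0.73105858 0.        ]
print(plain)   # [0.26894142 0.73105858]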

Dot product with the Value

w_v = w_kv[:, :, self.units:]          # (batch_size, prev_len + seq_len, units)
w_v = self._reshape_to_batches(w_v)    # (batch_size * n_head, prev_len + seq_len, units_head)
w_o = K.batch_dot(att, w_v)
# --> (batch_size * n_head, seq_len, units_head)

Here, "prev_len" disappears from the tensor shape.

Merge the heads back together.

w_o = self._reshape_from_batches(w_o)
# --> (batch_size, seq_len, units)
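
"_reshape_from_batches" is the inverse of the head split sketched earlier: the head axis is pulled back out of the batch axis and merged into the feature axis. Again a NumPy sketch of the assumed behaviour:

import numpy as np

def reshape_from_batches(x, n_head):
    # x: (batch_size * n_head, length, units_head) --> (batch_size, length, units_head * n_head)
    total, length, units_head = x.shape
    batch_size = total // n_head
    x = x.reshape(batch_size, n_head, length, units_head)   # recover the head axis
    x = x.transpose(0, 2, 1, 3)                              # put heads back next to the features
    return x.reshape(batch_size, length, n_head * units_head)

w_o = np.random.rand(24, 3, 64)             # dummy (batch_size * n_head, seq_len, units_head)
print(reshape_from_batches(w_o, 12).shape)  # (2, 3, 768)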

Dot product with the Output weights

w_o = K.dot(w_o, kernel_o)
# --> (batch_size, seq_len, units)

Adding the bias (when "use_bias" is True) and passing the result through the activation gives the output of the Attention layer.
The output shape is the same as BERT's.
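
Put as a rough sketch (NumPy, with hypothetical names; keras_xlnet's actual code differs, and with the default use_bias=False the bias step is simply skipped):

import numpy as np

def attention_output(w_o, kernel_o, bias_o=None, activation=None):
    # w_o: (batch_size, seq_len, units), kernel_o: (units, units)
    y = w_o @ kernel_o          # the Output projection shown above
    if bias_o is not None:      # only when use_bias=True
        y = y + bias_o          # bias_o: (units,)
    if activation is not None:
        y = activation(y)
    return y

y = attention_output(np.random.rand(2, 3, 768), np.random.rand(768, 768))
print(y.shape)  # (2, 3, 768) -- same shape as BERT's attention output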

To be continued in the next post.