Keras notes (XLNet), part 3
Continuing from the previous post.
Let's look at how the Attention layer differs from BERT's.
work-in-progress.hatenablog.com
Excerpt from the model.summary shown two posts ago
model.summary
Layer (type)                  Output Shape    Param #   Connected to
============================  ==============  ========  ==========================
Attention-1                   (None, 3, 768)  2949120   Embed-Token-Masked[0][0]
(RelativePartialMultiHead                               Embed-Token-Masked[0][0]
SelfAttention)                                          Memory-0[0][0]
                                                        Embed-Segment-1[0][0]
                                                        Embed-Segment-1[0][1]
                                                        Embed-Pos[0][0]
                                                        Relative-Bias-1[0][0]
                                                        Relative-Bias-1[0][1]
                                                        Segment-Bias-1[0][0]
                                                        Permutation[0][0]
keras_xlnet/attention.py
By default, "use_bias" is False.
The parameter count is 2,949,120 (= 768 * 768 * 5).
In addition to the Query, Key, Value, and Output kernels, the layer has a kernel for Relative (applied to the output of the PositionalEmbedding layer).
class RelativePartialMultiHeadSelfAttention(keras.layers.Layer):

    def __init__(self,
                 ...
                 use_bias=False,
                 ...

    def build(self, input_shape):
        # self.units: 768
        self.kernel = self.add_weight(
            shape=(self.units, self.units * 5),
            ...
        )

    def call(self, inputs, mask=None, training=None):
        ...
        kernel_q = self.kernel[:, :self.units]
        kernel_kv = self.kernel[:, self.units:self.units * 3]
        kernel_r = self.kernel[:, self.units * 3:self.units * 4]
        kernel_o = self.kernel[:, self.units * 4:self.units * 5]
        ...
(For reference) BERT
By default, "use_bias" is True.
The parameter count is 2,362,368 (= (768 * 768 * 4) + (768 * 4)).
Layer (type)                  Output Shape    Param #   Connected to
============================  ==============  ========  ==========================
Encoder-1-                    (None, 3, 768)  2362368   Embedding-Norm[0][0]
MultiHeadSelfAttention
(MultiHeadAttention)
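As a quick arithmetic check of the two parameter counts above (a minimal sketch; units = 768 is the hidden size used throughout this model):

units = 768

# XLNet: Query, Key, Value, Relative, and Output kernels, no bias
xlnet_attention_params = units * units * 5
print(xlnet_attention_params)   # 2949120

# BERT: Query, Key, Value, and Output kernels, each with a bias vector
bert_attention_params = (units * units * 4) + (units * 4)
print(bert_attention_params)    # 2362368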
Based on the source code, let's trace the internal processing of the attention layer at inference time.
(The per-head splitting and merging of tensors are mostly omitted from the diagrams; a rough sketch of these helpers follows right below.)
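Since _reshape_to_batches / _reshape_from_batches appear repeatedly below, here is a minimal NumPy sketch of what such a per-head split and merge typically looks like (an assumption based on the standard reshape-and-transpose pattern, not the exact keras_xlnet implementation):

import numpy as np

def reshape_to_batches(x, n_head):
    # (batch_size, seq_len, units) --> (batch_size * n_head, seq_len, units_head)
    batch_size, seq_len, units = x.shape
    units_head = units // n_head
    x = x.reshape(batch_size, seq_len, n_head, units_head)
    x = x.transpose(0, 2, 1, 3)           # (batch_size, n_head, seq_len, units_head)
    return x.reshape(batch_size * n_head, seq_len, units_head)

def reshape_from_batches(x, n_head):
    # (batch_size * n_head, seq_len, units_head) --> (batch_size, seq_len, units)
    total, seq_len, units_head = x.shape
    batch_size = total // n_head
    x = x.reshape(batch_size, n_head, seq_len, units_head)
    x = x.transpose(0, 2, 1, 3)           # (batch_size, seq_len, n_head, units_head)
    return x.reshape(batch_size, seq_len, n_head * units_head)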
Compute the dot product of the Query and its weights
# inputs: Embed-Token-Masked[0][0] (batch_size, seq_len, units)
# kernel_q: (units, units)
w_q = K.dot(inputs, kernel_q)
# --> (batch_size, seq_len, units)
Compute the dot product of the Key/Value and their weights
# memories: Memory-0[0][0] (batch_size, prev_len, units)
# content: Embed-Token-Masked[0][0] (batch_size, seq_len, units)
full = K.concatenate([memories, content], axis=1)
# --> (batch_size, prev_len + seq_len, units)

# kernel_kv: (units, units * 2)
w_kv = K.dot(full, kernel_kv)
# --> (batch_size, prev_len + seq_len, units * 2)
Compute the dot product of the PositionalEmbedding and its weights
# relatives: Embed-Pos[0][0] (batch_size, prev_len + seq_len + 1, units)
# kernel_r: (units, units)
w_r = K.dot(relatives, kernel_r)
# --> (batch_size, prev_len + seq_len + 1, units)
Dot product of Query and Key
w_k = w_kv[:, :, :self.units]
# --> (batch_size, prev_len + seq_len, units)

# bias_context: Relative-Bias-*[0][0]
w_qc = K.bias_add(w_q, bias_context)

# split the tensors per head
w_qc = self._reshape_to_batches(w_qc)
# --> (batch_size * n_head, seq_len, units_head)
w_k = self._reshape_to_batches(w_k)
# --> (batch_size * n_head, prev_len + seq_len, units_head)

# dot product of Query and Key
a_context = K.batch_dot(w_qc, w_k, axes=2)
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
Dot product of Query and PositionalEmbedding
# bias_relative: Relative-Bias-*[0][1]
w_qr = K.bias_add(w_q, bias_relative)

# split the tensors per head
w_qr = self._reshape_to_batches(w_qr)
# --> (batch_size * n_head, seq_len, units_head)
w_r = self._reshape_to_batches(w_r)
# --> (batch_size * n_head, prev_len + seq_len + 1, units_head)

# dot product of Query and PositionalEmbedding
a_relative = K.batch_dot(w_qr, w_r, axes=2)
# --> (batch_size * n_head, seq_len, prev_len + seq_len + 1)

# relative shift
a_relative = self._relative_shift(
    a_relative,
    key_len_expected=K.shape(a_context)[-1],
)
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
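For reference, a minimal NumPy sketch of the reshape-and-slice "relative shift" trick from Transformer-XL, which is what _relative_shift does here (the exact keras_xlnet implementation may differ in layout details):

import numpy as np

def relative_shift(x, key_len):
    # x: (batch, q_len, r_len) with r_len = key_len + 1 relative positions
    batch, q_len, r_len = x.shape
    x = x.reshape(batch, r_len, q_len)     # re-interpret each (q_len, r_len) block
    x = x[:, 1:, :]                        # drop the first q_len values, shifting each row
    x = x.reshape(batch, q_len, r_len - 1)
    return x[:, :, :key_len]               # keep only the expected key length

# toy usage: (batch=1, q_len=3, r_len=4) --> (1, 3, 3)
a_relative = np.arange(12, dtype=float).reshape(1, 3, 4)
print(relative_shift(a_relative, key_len=3).shape)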
Dot product of Query and RelativeSegmentEmbeddings
I have not yet fully understood the "RelativeSegmentEmbeddings" layer, so I will skip it here.
Its output is "a_segment".
Add together "a_context", "a_relative", and "a_segment" obtained from the three dot products.
att = (a_context + a_relative + a_segment) / K.sqrt(K.constant(self.units_head, dtype=K.floatx()))
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
Compare the output shapes at this point:
BERT
(batch_size * n_head, seq_len, seq_len)
XLNet
(batch_size * n_head, seq_len, prev_len + seq_len)
Next, the scaling step (in practice, the exponentiation for the softmax, with the row-wise maximum subtracted for numerical stability).
exp = K.exp(att - K.max(att, axis=-1, keepdims=True))
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
Computation using the output of the "PermutationMask" layer.
# permutation: Permutation[0][0] (Content mask)
# q_len: seq_len
# k_len: prev_len + seq_len
permutation = K.tile(K.expand_dims(permutation, axis=1),
                     [1, self.num_head, 1, 1])
permutation = K.reshape(permutation, (-1, q_len, k_len))
exp *= permutation
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
After masking, normalize (softmax).
att = exp / (K.sum(exp, axis=-1, keepdims=True) + K.epsilon())
# --> (batch_size * n_head, seq_len, prev_len + seq_len)
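The three steps above (exp, mask, normalize) amount to a masked softmax; a minimal NumPy sketch of the same computation (the mask is 1.0 where attention is allowed, 0.0 where it is forbidden):

import numpy as np

def masked_softmax(att, mask, epsilon=1e-9):
    # att, mask: (batch_size * n_head, seq_len, prev_len + seq_len)
    exp = np.exp(att - att.max(axis=-1, keepdims=True))  # numerically stable exp
    exp = exp * mask                                      # zero out forbidden positions
    return exp / (exp.sum(axis=-1, keepdims=True) + epsilon)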
Dot product with Value
w_v = w_kv[:, :, self.units:]
# (batch_size, prev_len + seq_len, units)
w_v = self._reshape_to_batches(w_v)
# (batch_size * n_head, prev_len + seq_len, units_head)
w_o = K.batch_dot(att, w_v)
# --> (batch_size * n_head, seq_len, units_head)
At this point, "prev_len" disappears from the tensor shape.
Merge the heads back together.
w_o = self._reshape_from_batches(w_o)
# --> (batch_size, seq_len, units)
Dot product with the Output kernel
w_o = K.dot(w_o, kernel_o)
# --> (batch_size, seq_len, units)
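To confirm the shapes from the Value dot product onward, here is a toy NumPy trace with random tensors (the toy sizes are assumptions; reshape_from_batches is the sketch shown earlier):

import numpy as np

batch_size, n_head, seq_len, prev_len, units = 2, 12, 3, 4, 768
units_head = units // n_head
k_len = prev_len + seq_len

att = np.random.rand(batch_size * n_head, seq_len, k_len)      # attention weights
w_v = np.random.rand(batch_size * n_head, k_len, units_head)   # per-head values
kernel_o = np.random.rand(units, units)                        # output kernel

w_o = att @ w_v                           # (24, 3, 64)  -- prev_len is gone
w_o = reshape_from_batches(w_o, n_head)   # (2, 3, 768)
out = w_o @ kernel_o                      # (2, 3, 768)
print(out.shape)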
Adding the bias to this and passing it through the activation gives the output of the Attention layer.
The output shape is the same as in BERT.
To be continued in the next post.