Kerasメモ(BERTその3)
前々回の続き。
Transformerを構成するMultiHeadAttentionレイヤを見てみる。
MultiHeadAttentionレイヤのインプットの形状が(bathc_size, 512, 768)、「head_num」が「12」である場合、並列化は下図のとおりとなる。
図中の「Wq」、「Wk」、「Wv」、「Wo」はMultiHeadAttentionレイヤ内の重みを表す。
class MultiHeadAttention(keras.layers.Layer): def __init__(self, head_num, activation='relu', use_bias=True, kernel_initializer='glorot_normal', bias_initializer='zeros', # <snip> **kwargs): # <snip> self.Wq, self.Wk, self.Wv, self.Wo = None, None, None, None self.bq, self.bk, self.bv, self.bo = None, None, None, None
model.summary
Layer (type) Output Shape Param # ============================================================================== Encoder-MultiHeadSelfAttention (MultiHeadAttention) (None, 512, 768) 2362368 ============================================================================== # Weight : (768, 768) * 4 # Bias : (768, ) * 4
重みは、Q(query)、K(key)、V(value)、O(Concat後のLinear)で、それぞれ持つ。
class MultiHeadAttention(keras.layers.Layer): def build(self, input_shape): # isinstance(input_shape, list) --> False # type(input_shape) --> tuple q = k = v = input_shape self.Wq = self.add_weight( shape=(int(q[-1]), feature_dim), ...) self.Wk = self.add_weight( shape=(int(k[-1]), feature_dim), ...) self.Wv = self.add_weight( shape=(int(v[-1]), feature_dim), ...) self.Wo = self.add_weight( shape=(feature_dim, feature_dim), ...)
レイヤロジック開始の段階で、Q(query)、K(key)、V(value)は同じ内容になる。
class MultiHeadAttention(keras.layers.Layer): def call(self, inputs, mask=None): # isinstance(inputs, list) --> False # type(inputs) --> <class 'tensorflow.python.framework.ops.Tensor'> q = k = v = inputs
入力と重みの内積を計算し、
q = K.dot(q, self.Wq) k = K.dot(k, self.Wk) v = K.dot(v, self.Wv)
テンソルを分割してScaledDotProductAttentionレイヤに渡す。
y = ScaledDotProductAttention(...)( inputs=[ self._reshape_to_batches(q, self.head_num), self._reshape_to_batches(k, self.head_num), self._reshape_to_batches(v, self.head_num), ], mask=[ self._reshape_mask(q_mask, self.head_num), self._reshape_mask(k_mask, self.head_num), self._reshape_mask(v_mask, self.head_num), ], )
テンソルの分割方法。
def _reshape_to_batches(x, head_num): input_shape = K.shape(x) batch_size = input_shape[0] seq_len = input_shape[1] feature_dim = input_shape[2] head_dim = feature_dim // head_num # 64 = 768 // 12 x = K.reshape(x, (batch_size, seq_len, head_num, head_dim)) # --> (batch_size, 512, 12, 64) x = K.permute_dimensions(x, [0, 2, 1, 3]) # --> (batch_size, 12, 512, 64) return K.reshape(x, (batch_size * head_num, seq_len, head_dim)) # --> (batch_size*12, 512, 64)
以降、「ScaledDotProductAttention」レイヤ。
「keras_self_attention」パッケージに属するようだが、PyPIの説明には記載がない。
後から追加されたことがgithub上で確認できる。
このレイヤはbuildメソッドが無い。すなわち学習する重みを持たない(ソース全体)
class ScaledDotProductAttention(keras.layers.Layer): def call(self, inputs, mask=None, **kwargs): # isinstance(inputs, list) --> True query, key, value = inputs
以降、冒頭の図’の左側の通り処理。
Matmul
e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx())) # e.shape --> (None, 512, 512)
Scale
「0」から「1」の範囲に収める。
e = K.exp(e - K.max(e, axis=-1, keepdims=True))
Mask
捨てるsentenceは「0」を掛ける。
# mask.shape --> (None, 512) # K.expand_dims(mask, axis=-2) --> (None, 1, 512) e *= K.cast(K.expand_dims(mask, axis=-2), K.floatx())
Softmax
# e.shape --> (None, 512, 512) a = e / (K.sum(e, axis=-1, keepdims=True) + K.epsilon()) # a : Attention weight # a.shape --> (None, 512, 512)
Matmul
# a.shape --> (None, 512, 512) # value.shape --> (None, 512, 64) v = K.batch_dot(a, value) # v.shape --> (None, 512, 64)
この"v"がScaledDotProductAttentionレイヤのOutputになる。
Attention Weightを可視化したい時は"a"を使う。
MultiHeadAttentionレイヤに戻る。
各headを統合後、重みWoを掛け、バイアスboを加算する。
multi_head_attention.py
class MultiHeadAttention(keras.layers.Layer): ... def call(self, inputs, mask=None): ... y = ScaledDotProductAttention(...)(...) y = self._reshape_from_batches(y, self.head_num) y = K.dot(y, self.Wo) if self.use_bias: y += self.bo if self.activation is not None: y = self.activation(y) return y # y.shape(None, 512, 768)