ichou1のブログ

主に音声認識、時々、データ分析のことを書く

Kerasメモ(BERTその3)

前々回の続き。

Transformerを構成するMultiHeadAttentionレイヤを見てみる。

MultiHeadAttentionレイヤのインプットの形状が(bathc_size, 512, 768)、「head_num」が「12」である場合、並列化は下図のとおりとなる。
f:id:ichou1:20190614200419p:plain

図中の「Wq」、「Wk」、「Wv」、「Wo」はMultiHeadAttentionレイヤ内の重みを表す。

class MultiHeadAttention(keras.layers.Layer):

    def __init__(self,
                 head_num,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
		# <snip>
                 **kwargs):
		# <snip>
        self.Wq, self.Wk, self.Wv, self.Wo = None, None, None, None
        self.bq, self.bk, self.bv, self.bo = None, None, None, None

ソース全体


model.summary
Layer (type)                                         Output Shape      Param #
==============================================================================
Encoder-MultiHeadSelfAttention (MultiHeadAttention)  (None, 512, 768)  2362368
==============================================================================
# Weight : (768, 768) * 4
# Bias   : (768, ) * 4

重みは、Q(query)、K(key)、V(value)、O(Concat後のLinear)で、それぞれ持つ。

class MultiHeadAttention(keras.layers.Layer):

    def build(self, input_shape):
        # isinstance(input_shape, list) --> False
        # type(input_shape) --> tuple
        q = k = v = input_shape

        self.Wq = self.add_weight( shape=(int(q[-1]), feature_dim), ...)
        self.Wk = self.add_weight( shape=(int(k[-1]), feature_dim), ...)
        self.Wv = self.add_weight( shape=(int(v[-1]), feature_dim), ...)
        self.Wo = self.add_weight( shape=(feature_dim, feature_dim), ...)

レイヤロジック開始の段階で、Q(query)、K(key)、V(value)は同じ内容になる。

class MultiHeadAttention(keras.layers.Layer):

    def call(self, inputs, mask=None):
        # isinstance(inputs, list) --> False
        # type(inputs) --> <class 'tensorflow.python.framework.ops.Tensor'>
        q = k = v = inputs

入力と重みの内積を計算し、

        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)

テンソルを分割してScaledDotProductAttentionレイヤに渡す。

        y = ScaledDotProductAttention(...)(
            inputs=[
                self._reshape_to_batches(q, self.head_num),
                self._reshape_to_batches(k, self.head_num),
                self._reshape_to_batches(v, self.head_num),
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
            )

テンソルの分割方法。

def _reshape_to_batches(x, head_num):
    input_shape = K.shape(x)

    batch_size = input_shape[0]
    seq_len = input_shape[1]
    feature_dim = input_shape[2]

    head_dim = feature_dim // head_num	# 64 = 768 // 12 
    x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
    # --> (batch_size, 512, 12, 64)
    x = K.permute_dimensions(x, [0, 2, 1, 3])
    # --> (batch_size, 12, 512, 64)
    return K.reshape(x, (batch_size * head_num, seq_len, head_dim))
    # --> (batch_size*12, 512, 64)


以降、「ScaledDotProductAttention」レイヤ。
「keras_self_attention」パッケージに属するようだが、PyPIの説明には記載がない。
後から追加されたことがgithub上で確認できる。


このレイヤはbuildメソッドが無い。すなわち学習する重みを持たない(ソース全体

class ScaledDotProductAttention(keras.layers.Layer):

    def call(self, inputs, mask=None, **kwargs):
        # isinstance(inputs, list) --> True
        query, key, value = inputs

以降、冒頭の図’の左側の通り処理。

Matmul



\frac{ Q \ K^T}{ \sqrt{ d_k } }

        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))
        # e.shape --> (None, 512, 512)
Scale

「0」から「1」の範囲に収める。

        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
Mask

捨てるsentenceは「0」を掛ける。

        # mask.shape --> (None, 512)
        # K.expand_dims(mask, axis=-2) --> (None, 1, 512)
        e *= K.cast(K.expand_dims(mask, axis=-2), K.floatx())
Softmax
        # e.shape --> (None, 512, 512)
        a = e / (K.sum(e, axis=-1, keepdims=True) + K.epsilon())
        # a : Attention weight
        # a.shape --> (None, 512, 512)
Matmul
        # a.shape --> (None, 512, 512)
        # value.shape --> (None, 512, 64)
        v = K.batch_dot(a, value)
        # v.shape --> (None, 512, 64)

この"v"がScaledDotProductAttentionレイヤのOutputになる。
Attention Weightを可視化したい時は"a"を使う。



MultiHeadAttentionレイヤに戻る。
各headを統合後、重みWoを掛け、バイアスboを加算する。

multi_head_attention.py
class MultiHeadAttention(keras.layers.Layer):
    ...
    def call(self, inputs, mask=None):
        ...
        y = ScaledDotProductAttention(...)(...)
        y = self._reshape_from_batches(y, self.head_num)
        y = K.dot(y, self.Wo)
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)
        return y  # y.shape(None, 512, 768)