ichou1's blog

Mostly about speech recognition, occasionally about data analysis.

BERT notes (torchsummary)

Trying out the PyTorch version of BERT.
pytorch-pretrained-bert · PyPI

While looking for a way to print a model summary, I found that a package called "torchsummary" is available.
torchsummary · PyPI

torchsummary requires the input size to be specified. Here the token sequence length is assumed to be 13.

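Note that torchsummary builds its dummy input with torch.rand(), which produces torch.FloatTensor values, so using it as-is crashes: BERT's embedding layers expect integer token IDs (torch.LongTensor).
For this post I modified torchsummary.py to match BERT's input.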
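An alternative to editing torchsummary.py is to run the forward pass yourself with integer token IDs and collect the rows with forward hooks. A minimal sketch, assuming PyTorch is installed; `summarize` and the two-layer toy model are hypothetical stand-ins, not torchsummary's API or the real BERT. It only hooks leaf modules, so container rows such as BertEmbeddings-6 would not appear.

```python
import torch
import torch.nn as nn


def summarize(model: nn.Module, example_input: torch.Tensor):
    """Minimal torchsummary-style table (hypothetical helper, not torchsummary's API).

    Hooks every leaf module, runs one forward pass on the caller's input (so integer
    token IDs are fine), and returns (name, output shape with batch as -1, #params).
    """
    rows, handles = [], []

    def hook(module, inputs, output):
        shape = [-1] + list(output.shape[1:])
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        rows.append((f"{type(module).__name__}-{len(rows) + 1}", shape, n_params))

    for module in model.modules():
        if not list(module.children()):  # leaf modules only
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()
    return rows


# Toy stand-in for BERT's embedding block (not the real model)
toy = nn.Sequential(nn.Embedding(28996, 768), nn.LayerNorm(768))

try:
    toy(torch.rand(2, 13))  # float input, like torchsummary's torch.rand() dummy
except RuntimeError as err:
    print("float input fails:", type(err).__name__)

token_ids = torch.randint(0, 28996, (2, 13))  # int64 (LongTensor) token IDs
for name, shape, n in summarize(toy, token_ids):
    print(f"{name:>15} {str(shape):>15} {n:>12,}")
```

With a real BertModel the same call would take the place of the toy model, at the cost of a much longer table.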

BERT Base

model = BertModel.from_pretrained('bert-base-cased')
        Layer (type)               Output Shape         Param #
================================================================
         Embedding-1              [-1, 13, 768]      22,268,928
         Embedding-2              [-1, 13, 768]         393,216
         Embedding-3              [-1, 13, 768]           1,536
     BertLayerNorm-4              [-1, 13, 768]           1,536
           Dropout-5              [-1, 13, 768]               0
    BertEmbeddings-6              [-1, 13, 768]               0
(For reference) Keras-BERT
Layer (type)                      Output Shape         Param #  
================================================================
Input-Token (InputLayer)          (None, 13)                   0
________________________________________________________________
Input-Segment (InputLayer)        (None, 13)                   0
________________________________________________________________
Embedding-Token (TokenEmbedding   [(None, 13, 768),
                                   (28996, 768)]      22,268,928
________________________________________________________________
Embedding-Segment (Embedding)     (None, 13, 768)          1,536
________________________________________________________________
Embedding-Token-Segment (Add)     (None, 13, 768)              0
________________________________________________________________
Embedding-Position (PositionEmb   (None, 13, 768)          9,984
________________________________________________________________
Embedding-Dropout (Dropout)       (None, 13, 768)              0
________________________________________________________________
Embedding-Norm (LayerNormalizat   (None, 13, 768)          1,536

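The parameter count of "Embedding-2", which presumably corresponds to "Embedding-Position", does not match.
Is the PyTorch version the "use sine and cosine functions" variant? Probably not: sinusoidal encodings have no trainable parameters. Since 393,216 = 512 × 768 and 9,984 = 13 × 768, both look like learned position tables, the PyTorch one covering 512 positions and the Keras-BERT one apparently sized to the 13-token input.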
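The embedding rows can be checked with a few lines of arithmetic. A sketch, assuming a vocabulary of 28,996 (bert-base-cased), 2 segment types, and 512 positions for the PyTorch table; `embedding_params` is a hypothetical helper.

```python
def embedding_params(num_rows: int, hidden: int) -> int:
    """Parameters of a learned lookup table: one `hidden`-sized vector per row."""
    return num_rows * hidden


HIDDEN = 768  # bert-base hidden size

token = embedding_params(28996, HIDDEN)           # Embedding-1 / Embedding-Token (cased vocab)
segment = embedding_params(2, HIDDEN)             # Embedding-3 / Embedding-Segment (2 token types)
position_pytorch = embedding_params(512, HIDDEN)  # Embedding-2: 512 positions (assumed max length)
position_keras = embedding_params(13, HIDDEN)     # Embedding-Position: 13 positions (input length)
position_sinusoidal = 0  # sine/cosine encodings are computed, so nothing is learned

print(f"{token:,} {segment:,} {position_pytorch:,} {position_keras:,}")
# 22,268,928 1,536 393,216 9,984
```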


Layer 1
            Linear-7              [-1, 13, 768]         590,592
            Linear-8              [-1, 13, 768]         590,592
            Linear-9              [-1, 13, 768]         590,592
          Dropout-10           [-1, 12, 13, 13]               0
BertSelfAttention-11              [-1, 13, 768]               0
           Linear-12              [-1, 13, 768]         590,592
          Dropout-13              [-1, 13, 768]               0
    BertLayerNorm-14              [-1, 13, 768]           1,536
   BertSelfOutput-15              [-1, 13, 768]               0
    BertAttention-16              [-1, 13, 768]               0
           Linear-17             [-1, 13, 3072]       2,362,368
 BertIntermediate-18             [-1, 13, 3072]               0
           Linear-19              [-1, 13, 768]       2,360,064
          Dropout-20              [-1, 13, 768]               0
    BertLayerNorm-21              [-1, 13, 768]           1,536
       BertOutput-22              [-1, 13, 768]               0
        BertLayer-23              [-1, 13, 768]               0
(For reference) Keras-BERT
Layer (type)                      Output Shape         Param #
================================================================
Encoder-1-MultiHeadSelfAttentio   (None, 13, 768)      2,362,368
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio   (None, 13, 768)              0
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio   (None, 13, 768)              0
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio   (None, 13, 768)          1,536
________________________________________________________________
Encoder-1-FeedForward (FeedForw   (None, 13, 768)      4,722,432
________________________________________________________________
Encoder-1-FeedForward-Dropout (   (None, 13, 768)              0
________________________________________________________________
Encoder-1-FeedForward-Add (Add)   (None, 13, 768)              0
________________________________________________________________
Encoder-1-FeedForward-Norm (Lay   (None, 13, 768)          1,536   
________________________________________________________________

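The totals match: the four attention Linear layers (Linear-7, 8, 9, 12) add up to 4 × 590,592 = 2,362,368, the same as Encoder-1-MultiHeadSelfAttention, and Linear-17 + Linear-19 = 2,362,368 + 2,360,064 = 4,722,432, the same as Encoder-1-FeedForward. The two LayerNorms (1,536 each) also match.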
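The per-layer counts follow from "weight + bias" for each Linear/Dense layer and "gain + bias" for each LayerNorm. A sketch using the bert-base sizes read off the Output Shape column (hidden 768, intermediate 3072); `linear_params` and `layer_norm_params` are hypothetical helpers.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Linear/Dense layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


def layer_norm_params(hidden: int) -> int:
    """LayerNorm: one gain and one bias per feature."""
    return 2 * hidden


H, FF = 768, 3072  # bert-base hidden size and intermediate (feed-forward) size

# Query, key, value (Linear-7/8/9) and the output projection (Linear-12)
attention = 4 * linear_params(H, H)
# Linear-17 (768 -> 3072) followed by Linear-19 (3072 -> 768)
feed_forward = linear_params(H, FF) + linear_params(FF, H)
per_layer = attention + feed_forward + 2 * layer_norm_params(H)

print(f"{attention:,} {feed_forward:,} {per_layer:,}")
# 2,362,368 4,722,432 7,087,872
```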

Layer 12
          Linear-194              [-1, 13, 768]         590,592
          Linear-195              [-1, 13, 768]         590,592
          Linear-196              [-1, 13, 768]         590,592
         Dropout-197           [-1, 12, 13, 13]               0
BertSelfAttention-198              [-1, 13, 768]               0
          Linear-199              [-1, 13, 768]         590,592
         Dropout-200              [-1, 13, 768]               0
   BertLayerNorm-201              [-1, 13, 768]           1,536
  BertSelfOutput-202              [-1, 13, 768]               0
   BertAttention-203              [-1, 13, 768]               0
          Linear-204             [-1, 13, 3072]       2,362,368
BertIntermediate-205             [-1, 13, 3072]               0
          Linear-206              [-1, 13, 768]       2,360,064
         Dropout-207              [-1, 13, 768]               0
   BertLayerNorm-208              [-1, 13, 768]           1,536
      BertOutput-209              [-1, 13, 768]               0
       BertLayer-210              [-1, 13, 768]               0
# site-packages/pytorch_pretrained_bert/modeling.py
# Outputs: Tuple of (encoded_layers, pooled_output)
# `encoded_layers`: a list of the full sequences of encoded-hidden-states
#                   at the end of each attention block
# `pooled_output` : the hidden state associated to the first character of the input (`CLS`) 
     BertEncoder-211  [[-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768], [-1, 13, 768]]               0
          Linear-212                  [-1, 768]         590,592
            Tanh-213                  [-1, 768]               0
      BertPooler-214                  [-1, 768]               0
================================================================
Total params: 108,310,272
Trainable params: 108,310,272
Non-trainable params: 0
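The modeling.py comment in the table says `pooled_output` comes from the first token (`CLS`). A minimal sketch of what the BertPooler rows (Linear-212, Tanh-213) compute, assuming PyTorch is installed and that the pooler is applied to the last entry of `encoded_layers`; `PoolerSketch` is a hypothetical stand-in, not the library class, and the inputs are random.

```python
import torch
import torch.nn as nn


class PoolerSketch(nn.Module):
    """Hypothetical re-creation of the BertPooler rows (Linear-212 + Tanh-213)."""

    def __init__(self, hidden: int = 768):
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)  # 768 * 768 + 768 = 590,592 parameters
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        first_token = hidden_states[:, 0]  # CLS position: [batch, 13, 768] -> [batch, 768]
        return self.activation(self.dense(first_token))


# Random stand-in for encoded_layers: 12 tensors of shape [batch, 13, 768]
encoded_layers = [torch.randn(2, 13, 768) for _ in range(12)]
pooler = PoolerSketch()
pooled_output = pooler(encoded_layers[-1])  # pooled from the last layer only

print(len(encoded_layers), tuple(pooled_output.shape))  # 12 (2, 768)
print(sum(p.numel() for p in pooler.parameters()))      # 590592
```

This is why BertEncoder-211 lists twelve [-1, 13, 768] shapes while BertPooler-214 has a single [-1, 768] output.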

BERT Large

model = BertModel.from_pretrained('bert-large-cased')
        Layer (type)               Output Shape         Param #
================================================================
         Embedding-1             [-1, 13, 1024]      29,691,904
         Embedding-2             [-1, 13, 1024]         524,288
         Embedding-3             [-1, 13, 1024]           2,048
     BertLayerNorm-4             [-1, 13, 1024]           2,048
           Dropout-5             [-1, 13, 1024]               0
    BertEmbeddings-6             [-1, 13, 1024]               0
Layer 1
            Linear-7             [-1, 13, 1024]       1,049,600
            Linear-8             [-1, 13, 1024]       1,049,600
            Linear-9             [-1, 13, 1024]       1,049,600
          Dropout-10           [-1, 16, 13, 13]               0
BertSelfAttention-11             [-1, 13, 1024]               0
           Linear-12             [-1, 13, 1024]       1,049,600
          Dropout-13             [-1, 13, 1024]               0
    BertLayerNorm-14             [-1, 13, 1024]           2,048
   BertSelfOutput-15             [-1, 13, 1024]               0
    BertAttention-16             [-1, 13, 1024]               0
           Linear-17             [-1, 13, 4096]       4,198,400
 BertIntermediate-18             [-1, 13, 4096]               0
           Linear-19             [-1, 13, 1024]       4,195,328
          Dropout-20             [-1, 13, 1024]               0
    BertLayerNorm-21             [-1, 13, 1024]           2,048
       BertOutput-22             [-1, 13, 1024]               0
        BertLayer-23             [-1, 13, 1024]               0
Layer 24
          Linear-398             [-1, 13, 1024]       1,049,600 
          Linear-399             [-1, 13, 1024]       1,049,600 
          Linear-400             [-1, 13, 1024]       1,049,600 
         Dropout-401           [-1, 16, 13, 13]               0 
BertSelfAttention-402            [-1, 13, 1024]               0 
          Linear-403             [-1, 13, 1024]       1,049,600 
         Dropout-404             [-1, 13, 1024]               0 
   BertLayerNorm-405             [-1, 13, 1024]           2,048 
  BertSelfOutput-406             [-1, 13, 1024]               0 
   BertAttention-407             [-1, 13, 1024]               0 
          Linear-408             [-1, 13, 4096]       4,198,400 
BertIntermediate-409             [-1, 13, 4096]               0 
          Linear-410             [-1, 13, 1024]       4,195,328 
         Dropout-411             [-1, 13, 1024]               0 
   BertLayerNorm-412             [-1, 13, 1024]           2,048 
      BertOutput-413             [-1, 13, 1024]               0 
       BertLayer-414             [-1, 13, 1024]               0 
     BertEncoder-415  [[-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024]]               0
          Linear-416                 [-1, 1024]       1,049,600
            Tanh-417                 [-1, 1024]               0
      BertPooler-418                 [-1, 1024]               0
================================================================
Total params: 333,579,264
Trainable params: 333,579,264
Non-trainable params: 0
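Both totals can be reproduced from the per-row counts. A sketch, assuming the cased vocabulary (28,996), 512 positions, 2 segment types, and the hidden/intermediate sizes visible in the Output Shape column (768/3072 for Base, 1024/4096 for Large); `bert_params` is a hypothetical helper.

```python
def bert_params(vocab: int, hidden: int, layers: int, intermediate: int,
                max_pos: int = 512, type_vocab: int = 2) -> int:
    """Total BertModel parameters, rebuilt from the rows in the summaries above."""

    def linear(i: int, o: int) -> int:
        return i * o + o  # weight + bias

    layer_norm = 2 * hidden  # gain + bias
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    per_layer = (4 * linear(hidden, hidden)        # Q, K, V, attention output
                 + linear(hidden, intermediate)    # BertIntermediate
                 + linear(intermediate, hidden)    # BertOutput
                 + 2 * layer_norm)
    pooler = linear(hidden, hidden)
    return embeddings + layers * per_layer + pooler


base = bert_params(vocab=28996, hidden=768, layers=12, intermediate=3072)
large = bert_params(vocab=28996, hidden=1024, layers=24, intermediate=4096)
print(f"{base:,} {large:,}")  # 108,310,272 333,579,264
```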