BERT notes (torchsummary)
Trying out the PyTorch version of BERT.
pytorch-pretrained-bert · PyPI
Looking for a way to print a model summary, I found that a package called "torchsummary" is available.
torchsummary · PyPI
torchsummary requires the input size to be specified; here the token length is taken to be 13.
Note that torchsummary builds its dummy input with torch.rand(), which produces torch.FloatTensor values, so it crashes when used as-is: BERT's first layer is an embedding lookup that expects integer token ids.
For this experiment I therefore rewrote torchsummary.py to match BERT's input; a sketch of the change follows.
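The gist of the patch is to swap the random float input for random integer token ids. A minimal sketch, assuming torchsummary's internals at the time; the names `x`, `in_size`, and `input_size` come from its `summary()` function, and the bound 28,996 is bert-base-cased's vocabulary size:

```python
# torchsummary.py, inside summary() -- the original builds float inputs:
#   x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
# BERT's first layer is an nn.Embedding, which requires integer indices,
# so generate random token ids within the vocabulary range instead:
x = [torch.randint(low=0, high=28996, size=(2, *in_size)).long()
     for in_size in input_size]
```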
BERT Base
```python
model = BertModel.from_pretrained('bert-base-cased')
```
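The summaries below were then produced along these lines; a sketch, assuming the patched copy keeps torchsummary's `summary(model, input_size, batch_size=-1, device='cuda')` signature and that CPU execution is wanted here:

```python
from torchsummary import summary  # the locally patched copy described above

# `model` is the BertModel loaded above; token length is 13
summary(model, input_size=(13,), device='cpu')
```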
```
        Layer (type)               Output Shape         Param #
================================================================
         Embedding-1              [-1, 13, 768]      22,268,928
         Embedding-2              [-1, 13, 768]         393,216
         Embedding-3              [-1, 13, 768]           1,536
     BertLayerNorm-4              [-1, 13, 768]           1,536
           Dropout-5              [-1, 13, 768]               0
    BertEmbeddings-6              [-1, 13, 768]               0
```
(Reference) Keras-BERT
```
Layer (type)                    Output Shape                     Param #
================================================================
Input-Token (InputLayer)        (None, 13)                       0
________________________________________________________________
Input-Segment (InputLayer)      (None, 13)                       0
________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 13, 768), (28996, 768)]  22,268,928
________________________________________________________________
Embedding-Segment (Embedding)   (None, 13, 768)                  1,536
________________________________________________________________
Embedding-Token-Segment (Add)   (None, 13, 768)                  0
________________________________________________________________
Embedding-Position (PositionEmb (None, 13, 768)                  9,984
________________________________________________________________
Embedding-Dropout (Dropout)     (None, 13, 768)                  0
________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 13, 768)                  1,536
```
The parameter count of Embedding-2, which presumably corresponds to Embedding-Position, does not match.
Is the PyTorch version the one that "use[s] sine and cosine functions"? Apparently not: a sinusoidal encoding would contribute no parameters at all, and 393,216 = 512 × 768 is exactly a learned embedding over BERT's maximum position length of 512, whereas Keras-BERT's 9,984 = 13 × 768 covers only the 13 positions of the actual input.
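The factorizations can be checked directly (plain arithmetic on the numbers in the two summaries):

```python
hidden = 768
print(512 * hidden)    # 393216   -> Embedding-2: learned positions up to
                       #             max_position_embeddings = 512
print(13 * hidden)     # 9984     -> Keras-BERT Embedding-Position: only the
                       #             13 positions of the actual input
print(28996 * hidden)  # 22268928 -> token embedding; identical in both summaries
```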
Layer 1
```
            Linear-7              [-1, 13, 768]         590,592
            Linear-8              [-1, 13, 768]         590,592
            Linear-9              [-1, 13, 768]         590,592
          Dropout-10           [-1, 12, 13, 13]               0
BertSelfAttention-11              [-1, 13, 768]               0
           Linear-12              [-1, 13, 768]         590,592
          Dropout-13              [-1, 13, 768]               0
    BertLayerNorm-14              [-1, 13, 768]           1,536
   BertSelfOutput-15              [-1, 13, 768]               0
    BertAttention-16              [-1, 13, 768]               0
           Linear-17             [-1, 13, 3072]       2,362,368
 BertIntermediate-18             [-1, 13, 3072]               0
           Linear-19              [-1, 13, 768]       2,360,064
          Dropout-20              [-1, 13, 768]               0
    BertLayerNorm-21              [-1, 13, 768]           1,536
       BertOutput-22              [-1, 13, 768]               0
        BertLayer-23              [-1, 13, 768]               0
```
(Reference) Keras-BERT
```
Layer (type)                    Output Shape        Param #
================================================================
Encoder-1-MultiHeadSelfAttentio (None, 13, 768)     2,362,368
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 13, 768)     0
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 13, 768)     0
________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 13, 768)     1,536
________________________________________________________________
Encoder-1-FeedForward (FeedForw (None, 13, 768)     4,722,432
________________________________________________________________
Encoder-1-FeedForward-Dropout ( (None, 13, 768)     0
________________________________________________________________
Encoder-1-FeedForward-Add (Add) (None, 13, 768)     0
________________________________________________________________
Encoder-1-FeedForward-Norm (Lay (None, 13, 768)     1,536
________________________________________________________________
```
The per-layer parameter totals agree: 7,087,872 on both sides.
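As a sanity check, that per-layer total can be reproduced from the hidden and intermediate sizes alone. A minimal sketch; the helper name `bert_layer_params` is mine, and the formula simply sums the Linear (weight + bias) and LayerNorm (gamma + beta) entries listed above:

```python
def bert_layer_params(hidden, intermediate):
    qkv      = 3 * (hidden * hidden + hidden)        # query/key/value projections
    attn_out = hidden * hidden + hidden              # attention output projection
    ffn_in   = hidden * intermediate + intermediate  # feed-forward up-projection
    ffn_out  = intermediate * hidden + hidden        # feed-forward down-projection
    norms    = 2 * (2 * hidden)                      # two LayerNorms, gamma + beta
    return qkv + attn_out + ffn_in + ffn_out + norms

print(bert_layer_params(768, 3072))   # 7087872  (Base, matches both summaries)
print(bert_layer_params(1024, 4096))  # 12596224 (Large, see below)
```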
Layer 12
```
          Linear-194              [-1, 13, 768]         590,592
          Linear-195              [-1, 13, 768]         590,592
          Linear-196              [-1, 13, 768]         590,592
         Dropout-197           [-1, 12, 13, 13]               0
BertSelfAttention-198             [-1, 13, 768]               0
          Linear-199              [-1, 13, 768]         590,592
         Dropout-200              [-1, 13, 768]               0
   BertLayerNorm-201              [-1, 13, 768]           1,536
  BertSelfOutput-202              [-1, 13, 768]               0
   BertAttention-203              [-1, 13, 768]               0
          Linear-204             [-1, 13, 3072]       2,362,368
BertIntermediate-205             [-1, 13, 3072]               0
          Linear-206              [-1, 13, 768]       2,360,064
         Dropout-207              [-1, 13, 768]               0
   BertLayerNorm-208              [-1, 13, 768]           1,536
      BertOutput-209              [-1, 13, 768]               0
       BertLayer-210              [-1, 13, 768]               0
```
```python
# site-packages/pytorch_pretrained_bert/modeling.py
# Outputs: Tuple of (encoded_layers, pooled_output)
# `encoded_layers`: a list of the full sequences of encoded-hidden-states
#                   at the end of each attention block
# `pooled_output` : the hidden state associated to the first character of
#                   the input (`CLS`)
```
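For reference, a minimal sketch of what that tuple looks like in practice, assuming pytorch-pretrained-bert's forward signature `model(input_ids, token_type_ids, attention_mask)` with `output_all_encoded_layers` left at its default of True (the dummy ids below are made up; `model` is the Base model loaded above):

```python
import torch

tokens = torch.randint(0, 28996, (1, 13))  # one dummy sentence of 13 token ids
segments = torch.zeros_like(tokens)        # single-segment input
model.eval()
with torch.no_grad():
    encoded_layers, pooled_output = model(tokens, segments)
print(len(encoded_layers), encoded_layers[-1].shape)  # 12 torch.Size([1, 13, 768])
print(pooled_output.shape)                            # torch.Size([1, 768])
```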
```
     BertEncoder-211  [[-1, 13, 768], [-1, 13, 768], [-1, 13, 768],
                       [-1, 13, 768], [-1, 13, 768], [-1, 13, 768],
                       [-1, 13, 768], [-1, 13, 768], [-1, 13, 768],
                       [-1, 13, 768], [-1, 13, 768], [-1, 13, 768]]      0
          Linear-212                   [-1, 768]         590,592
            Tanh-213                   [-1, 768]               0
      BertPooler-214                   [-1, 768]               0
================================================================
Total params: 108,310,272
Trainable params: 108,310,272
Non-trainable params: 0
```
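The grand total also checks out against the blocks above (plain arithmetic on the reported numbers; 7,087,872 is the per-layer total verified earlier):

```python
embeddings = 22_268_928 + 393_216 + 1_536 + 1_536  # token + position + segment + LayerNorm
pooler = 768 * 768 + 768                           # Linear-212: 590,592
print(embeddings + 12 * 7_087_872 + pooler)        # 108310272 -- matches "Total params"
```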
BERT Large
```python
model = BertModel.from_pretrained('bert-large-cased')
```
```
        Layer (type)               Output Shape         Param #
================================================================
         Embedding-1             [-1, 13, 1024]      29,691,904
         Embedding-2             [-1, 13, 1024]         524,288
         Embedding-3             [-1, 13, 1024]           2,048
     BertLayerNorm-4             [-1, 13, 1024]           2,048
           Dropout-5             [-1, 13, 1024]               0
    BertEmbeddings-6             [-1, 13, 1024]               0
```
Layer 1
```
            Linear-7             [-1, 13, 1024]       1,049,600
            Linear-8             [-1, 13, 1024]       1,049,600
            Linear-9             [-1, 13, 1024]       1,049,600
          Dropout-10           [-1, 16, 13, 13]               0
BertSelfAttention-11             [-1, 13, 1024]               0
           Linear-12             [-1, 13, 1024]       1,049,600
          Dropout-13             [-1, 13, 1024]               0
    BertLayerNorm-14             [-1, 13, 1024]           2,048
   BertSelfOutput-15             [-1, 13, 1024]               0
    BertAttention-16             [-1, 13, 1024]               0
           Linear-17             [-1, 13, 4096]       4,198,400
 BertIntermediate-18             [-1, 13, 4096]               0
           Linear-19             [-1, 13, 1024]       4,195,328
          Dropout-20             [-1, 13, 1024]               0
    BertLayerNorm-21             [-1, 13, 1024]           2,048
       BertOutput-22             [-1, 13, 1024]               0
        BertLayer-23             [-1, 13, 1024]               0
```
Layer 24
```
          Linear-398             [-1, 13, 1024]       1,049,600
          Linear-399             [-1, 13, 1024]       1,049,600
          Linear-400             [-1, 13, 1024]       1,049,600
         Dropout-401           [-1, 16, 13, 13]               0
BertSelfAttention-402            [-1, 13, 1024]               0
          Linear-403             [-1, 13, 1024]       1,049,600
         Dropout-404             [-1, 13, 1024]               0
   BertLayerNorm-405             [-1, 13, 1024]           2,048
  BertSelfOutput-406             [-1, 13, 1024]               0
   BertAttention-407             [-1, 13, 1024]               0
          Linear-408             [-1, 13, 4096]       4,198,400
BertIntermediate-409             [-1, 13, 4096]               0
          Linear-410             [-1, 13, 1024]       4,195,328
         Dropout-411             [-1, 13, 1024]               0
   BertLayerNorm-412             [-1, 13, 1024]           2,048
      BertOutput-413             [-1, 13, 1024]               0
       BertLayer-414             [-1, 13, 1024]               0
```
```
     BertEncoder-415  [[-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024],
                       [-1, 13, 1024], [-1, 13, 1024], [-1, 13, 1024]]    0
          Linear-416                  [-1, 1024]       1,049,600
            Tanh-417                  [-1, 1024]               0
      BertPooler-418                  [-1, 1024]               0
================================================================
Total params: 333,579,264
Trainable params: 333,579,264
Non-trainable params: 0
```
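The same arithmetic reproduces the Large total (24 layers, hidden size 1024, per-layer total 12,596,224 as computed earlier):

```python
embeddings = 29_691_904 + 524_288 + 2_048 + 2_048  # token + position + segment + LayerNorm
pooler = 1024 * 1024 + 1024                        # Linear-416: 1,049,600
print(embeddings + 24 * 12_596_224 + pooler)       # 333579264 -- matches "Total params"
```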