Keras notes (XLNet): part 1
Let's try out the Keras implementation of XLNet.
The load_trained_model_from_checkpoint function loads a pre-trained model.
Both the Base and Large models are supported.
The model's configuration changes depending on the "in_train_phase" parameter.
With in_train_phase=False, the output is a feature vector per token.
The feature for each token, with shape (batch_size, target_len, units).
With in_train_phase=True, the output is, for each token, a probability over every entry in the vocabulary.
The probability of each token in each position, with shape (batch_size, target_len, num_token).
There is also one more input than in the False case.
Masks of tokens, with shape (batch_size, target_len).
This presumably corresponds to the return value of the "_sample_mask" function in the original data_utils.py.
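As a rough illustration of what such a mask input looks like, here is a toy sketch. This is NOT the actual _sample_mask logic (which samples whole SentencePiece spans); it only shows the shape of the idea: a 0/1 vector over target positions marking which tokens are prediction targets.

```python
import random

def toy_sample_mask(target_len, num_predict, seed=0):
    """Toy illustration only: mark `num_predict` of the `target_len`
    positions as prediction targets (1) and the rest as context (0).
    The real _sample_mask in data_utils.py samples contiguous spans."""
    rng = random.Random(seed)
    positions = rng.sample(range(target_len), num_predict)
    return [1 if i in positions else 0 for i in range(target_len)]

mask = toy_sample_mask(target_len=7, num_predict=2)
print(mask)       # seven 0/1 entries, exactly two of them 1
print(sum(mask))  # 2
```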
All of the demos use in_train_phase=False, so let's start by examining that case.
(The demos provide fine-tuning on the GLUE tasks.)
keras-xlnet/demo at master · CyberZHG/keras-xlnet · GitHub
First, the feature-extraction (extract) example.
https://github.com/CyberZHG/keras-xlnet/tree/master/demo/extract
There are two variants that differ in the following parameters:
- memory_len (maximum memory length)
- target_len (maximum length of targets)
The input sentence:
text = "All's right with the world"
Its encoding (token length = 7):
['▁All', "'", 's', '▁right', '▁with', '▁the', '▁world']
token_embeddings.py (memory_len=0, target_len=7)

model = load_trained_model_from_checkpoint(
    config_path=config_path,
    checkpoint_path=model_path,
    batch_size=1,
    memory_len=0,   # <--
    target_len=7,   # <--
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_UNI,
)
token_embeddings_with_memory.py (memory_len=7, target_len=3)

model = load_trained_model_from_checkpoint(
    config_path=config_path,
    checkpoint_path=model_path,
    batch_size=1,
    memory_len=7,   # <--
    target_len=3,   # <--
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_UNI,
)
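The intuition behind memory_len=7, target_len=3 is that a sequence longer than target_len is fed to the model in consecutive chunks, with earlier chunks cached as memory. The splitting itself is plain list slicing; the actual memory caching happens inside the model. A minimal sketch of how the 7-token sentence would be chunked:

```python
def split_into_chunks(tokens, target_len):
    """Split a token sequence into consecutive chunks of at most
    `target_len` tokens; earlier chunks become the model's memory."""
    return [tokens[i:i + target_len] for i in range(0, len(tokens), target_len)]

tokens = ['▁All', "'", 's', '▁right', '▁with', '▁the', '▁world']
chunks = split_into_chunks(tokens, target_len=3)
print(chunks)
# [['▁All', "'", 's'], ['▁right', '▁with', '▁the'], ['▁world']]
```

(The demo script may additionally pad the final short chunk up to target_len; this sketch only shows the windowing.)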
Looking at the GLUE-task training in the demos, the sequence length is always set so that it fits within "target_len".
The longest is STS-B, at 140.
cola.py
SEQ_LEN = 32

model = load_trained_model_from_checkpoint(
    config_path=paths.config,
    checkpoint_path=paths.model,
    batch_size=BATCH_SIZE,
    memory_len=0,
    target_len=SEQ_LEN,
    in_train_phase=False,
    attention_type='bi',
)
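Since every GLUE sequence must fit within target_len (= SEQ_LEN), shorter sequences are padded and longer ones truncated before being fed to the model. A minimal sketch of that preprocessing step, assuming a pad id of 0 (the actual demo would use the tokenizer's own pad token):

```python
def pad_or_truncate(token_ids, seq_len, pad_id=0):
    """Force a sequence of token ids to exactly `seq_len` entries:
    truncate if too long, right-pad with `pad_id` if too short."""
    if len(token_ids) >= seq_len:
        return token_ids[:seq_len]
    return token_ids + [pad_id] * (seq_len - len(token_ids))

print(pad_or_truncate([5, 8, 13], 5))        # [5, 8, 13, 0, 0]
print(pad_or_truncate(list(range(10)), 5))   # [0, 1, 2, 3, 4]
```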
Let's print the model summary for the variant with a non-zero "memory_len" (token_embeddings_with_memory.py).
model.summary(Input)
Layer (type)              Output Shape          Param #   Connected to
========================================================================
Input-Token (InputLayer)  (None, 3)             0
________________________________________________________________________
Embed-Token               [(None, 3, 768),      24576000  Input-Token[0][0]
(EmbeddingRet)             (32000, 768)]
________________________________________________________________________
Masking (CreateMask)      (None, 3)             0         Input-Token[0][0]
________________________________________________________________________
Embed-Token-Masked        (None, 3, 768)        0         Embed-Token[0][0]
(RestoreMask)                                             Masking[0][0]
________________________________________________________________________
Input-Memory-Length       (None, 1)             0
(InputLayer)
________________________________________________________________________
Memory-0 (Memory)         (None, None, 768)     7680      Embed-Token-Masked[0][0]
                                                          Input-Memory-Length[0][0]
________________________________________________________________________
Input-Segment             (None, 3)             0
(InputLayer)
________________________________________________________________________
Embed-Pos                 (None, None, 768)     0         Embed-Token-Masked[0][0]
(PositionalEmbedding)                                     Memory-0[0][0]
________________________________________________________________________
Permutation               [(None, 3, None),     0         Embed-Token-Masked[0][0]
(PermutationMask)          (None, 3, None)]               Memory-0[0][0]
________________________________________________________________________
Permutation[0][0] : Content mask
Permutation[0][1] : Query mask
model.summary(Layer 1)
Layer (type)              Output Shape          Param #   Connected to
========================================================================
Embed-Segment-1           [(None, 3, None, 2),  1536      Input-Segment[0][0]
(Relative-                 (2, 768)]                      Memory-0[0][0]
SegmentEmbeddings)
________________________________________________________________________
Relative-Bias-1           [(768,), (768,)]      1536      Memory-0[0][0]
(RelativeBias)
________________________________________________________________________
Segment-Bias-1            (768,)                768       Memory-0[0][0]
(SegmentBias)
________________________________________________________________________
Attention-1               (None, 3, 768)        2949120   Embed-Token-Masked[0][0]
(RelativePartial-                                         Embed-Token-Masked[0][0]
MultiHeadAttention)                                       Memory-0[0][0]
                                                          Embed-Segment-1[0][0]
                                                          Embed-Segment-1[0][1]
                                                          Embed-Pos[0][0]
                                                          Relative-Bias-1[0][0]
                                                          Relative-Bias-1[0][1]
                                                          Segment-Bias-1[0][0]
                                                          Permutation[0][0]
________________________________________________________________________
Attention-Residual-1      (None, 3, 768)        0         Embed-Token-Masked[0][0]
(Add)                                                     Attention-1[0][0]
________________________________________________________________________
Attention-Normal-1        (None, 3, 768)        1536      Attention-Residual-1[0][0]
(LayerNormalization)
________________________________________________________________________
FeedForward-1             (None, 3, 768)        4722432   Attention-Normal-1[0][0]
(FeedForward)
________________________________________________________________________
FeedForward-Residual-1    (None, 3, 768)        0         Attention-Normal-1[0][0]
(Add)                                                     FeedForward-1[0][0]
________________________________________________________________________
FeedForward-Normal-1      (None, 3, 768)        1536      FeedForward-Residual-1[0][0]
(LayerNormalization)
________________________________________________________________________
Memory-1 (Memory)         (None, None, 768)     7680      FeedForward-Normal-1[0][0]
                                                          Input-Memory-Length[0][0]
________________________________________________________________________
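The parameter counts in the summary can be reproduced from the Base configuration (units=768, hidden=3072, vocab=32000). A quick arithmetic check, assuming Attention-1 consists of four 768x768 projections (Q, K, V, output) plus one 768x768 relative-position projection, and FeedForward-1 is two dense layers with biases:

```python
units, hidden, vocab = 768, 3072, 32000

embed_token = vocab * units            # token embedding matrix
attention = 5 * units * units          # Q, K, V, output, plus the
                                       # relative-position projection
feed_forward = (units * hidden + hidden) + (hidden * units + units)
layer_norm = 2 * units                 # gamma and beta

print(embed_token)   # 24576000
print(attention)     # 2949120
print(feed_forward)  # 4722432
print(layer_norm)    # 1536
```

All four values match the Param # column above, which supports the assumed breakdown.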
model.summary(Layer 12)
Layer (type)              Output Shape          Param #   Connected to
========================================================================
Embed-Segment-12          [(None, 3, None, 2),  1536      Input-Segment[0][0]
(Relative-                 (2, 768)]                      Memory-11[0][0]
SegmentEmbeddings)
________________________________________________________________________
Relative-Bias-12          [(768,), (768,)]      1536      Memory-11[0][0]
(RelativeBias)
________________________________________________________________________
Segment-Bias-12           (768,)                768       Memory-11[0][0]
(SegmentBias)
________________________________________________________________________
Attention-12              (None, 3, 768)        2949120   FeedForward-Normal-11[0][0]
(RelativePartial-                                         FeedForward-Normal-11[0][0]
MultiHeadAttention)                                       Memory-11[0][0]
                                                          Embed-Segment-12[0][0]
                                                          Embed-Segment-12[0][1]
                                                          Embed-Pos[0][0]
                                                          Relative-Bias-12[0][0]
                                                          Relative-Bias-12[0][1]
                                                          Segment-Bias-12[0][0]
                                                          Permutation[0][0]
________________________________________________________________________
Attention-Residual-12     (None, 3, 768)        0         FeedForward-Normal-11[0][0]
(Add)                                                     Attention-12[0][0]
________________________________________________________________________
Attention-Normal-12       (None, 3, 768)        1536      Attention-Residual-12[0][0]
(LayerNormalization)
________________________________________________________________________
FeedForward-12            (None, 3, 768)        4722432   Attention-Normal-12[0][0]
(FeedForward)
________________________________________________________________________
FeedForward-Residual-12   (None, 3, 768)        0         Attention-Normal-12[0][0]
(Add)                                                     FeedForward-12[0][0]
________________________________________________________________________
FeedForward-Normal-12     (None, 3, 768)        1536      FeedForward-Residual-12[0][0]
(LayerNormalization)
To be continued in the next post.