
Keras notes (XLNet), part 1

Let's try out a Keras implementation of XLNet.

keras-xlnet · PyPI

Pre-trained models are loaded with the load_trained_model_from_checkpoint function.
Both the Base and Large models are supported.

The structure of the model changes depending on the "in_train_phase" parameter.

With in_train_phase=False, the output is a feature vector for each token.

The feature for each token, with shape (batch_size, target_len, units).

With in_train_phase=True, the output is, for each token, a probability for every entry in the vocabulary.

The probability of each token in each position, with shape (batch_size, target_len, num_token).

There is also one more input than in the False case.

Masks of tokens, with shape (batch_size, target_len).

This appears to correspond to the return value of the "_sample_mask" function in the original data_utils.py.
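
To see the difference concretely, here is a minimal sketch (the checkpoint paths are placeholders, and the shape comments assume the Base model with units=768 and a 32,000-token vocabulary):

from keras_xlnet import load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

# Placeholder paths to the pre-trained Base checkpoint files (assumption).
config_path = 'xlnet_config.json'
model_path = 'xlnet_model.ckpt'

for in_train_phase in (False, True):
    model = load_trained_model_from_checkpoint(
        config_path=config_path,
        checkpoint_path=model_path,
        batch_size=1,
        memory_len=0,
        target_len=7,
        in_train_phase=in_train_phase,
        attention_type=ATTENTION_TYPE_BI,
    )
    print(in_train_phase, len(model.inputs), model.output_shape)
    # False -> 3 inputs, output shape (batch_size, 7, 768)  : a feature vector per token
    # True  -> 4 inputs, output shape (batch_size, 7, 32000): probabilities over the vocabulary,
    #          the extra input being the token masks of shape (batch_size, 7)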



All of the demos are in_train_phase=False examples, so let's check the in_train_phase=False case first.
(The demos provide fine-tuning on the GLUE tasks.)
keras-xlnet/demo at master · CyberZHG/keras-xlnet · GitHub

Let's start with the feature-extraction (extract) example.
https://github.com/CyberZHG/keras-xlnet/tree/master/demo/extract

There are two patterns that differ in the following parameters:

  • memory_len (Maximum memory length)
  • target_len (Maximum length of targets)

The input sentence:

text = "All's right with the world"

And its encoded result (token length = 7):

['▁All', "'", 's', '▁right', '▁with', '▁the', '▁world']
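
These pieces come from XLNet's SentencePiece model. A minimal sketch to reproduce the encoding (assuming the spiece.model file shipped with the pre-trained checkpoint; the path is a placeholder):

import sentencepiece as spm

# spiece.model is bundled with the pre-trained checkpoint (path is a placeholder).
sp = spm.SentencePieceProcessor()
sp.load('spiece.model')

text = "All's right with the world"
print(sp.encode_as_pieces(text))
# ['▁All', "'", 's', '▁right', '▁with', '▁the', '▁world']
print(sp.encode_as_ids(text))  # the ids that are fed to Input-Token
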
token_embeddings.py (memory_len=0, target_len=7)
from keras_xlnet import load_trained_model_from_checkpoint, ATTENTION_TYPE_UNI

model = load_trained_model_from_checkpoint(
    config_path=config_path,
    checkpoint_path=model_path,
    batch_size=1,
    memory_len=0,  # <--
    target_len=7,  # <--
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_UNI,
)
token_embeddings_with_memory.py (memory_len=7, target_len=3)
model = load_trained_model_from_checkpoint(
    config_path=config_path,
    checkpoint_path=model_path,
    batch_size=1,
    memory_len=7,  # <--
    target_len=3,  # <--
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_UNI,
)
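
To make the memory mechanism concrete, here is a minimal sketch that runs the 7 tokens through the memory_len=7, target_len=3 model in chunks of 3, continuing from the tokenizer sketch above. The input order (tokens, segments, memory length), the meaning of the memory-length input, and the zero pad id are assumptions, so check them against the demo script:

import numpy as np

token_ids = sp.encode_as_ids(text)       # 7 ids from the tokenizer sketch
token_ids += [0] * (9 - len(token_ids))  # pad to a multiple of target_len=3 (pad id 0 is an assumption)

for step in range(3):
    chunk = np.array([token_ids[step * 3:(step + 1) * 3]])  # shape (1, 3)
    segments = np.zeros_like(chunk)                         # a single segment
    memory_length = np.array([[step * 3]])                  # how many past tokens are held in memory
    features = model.predict([chunk, segments, memory_length])
    print(features.shape)  # (1, 3, 768); later chunks attend to the memorized hidden states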

Looking at the GLUE-task training in the demos, the sequence length is set to fit within "target_len".
The longest is STS-B, at 140.

cola.py
SEQ_LEN = 32
model = load_trained_model_from_checkpoint(
    config_path=paths.config,
    checkpoint_path=paths.model,
    batch_size=BATCH_SIZE,
    memory_len=0,
    target_len=SEQ_LEN,
    in_train_phase=False,
    attention_type='bi',
)
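
For reference, the usual fine-tuning pattern on top of this model (not the demo's exact code; a minimal sketch assuming XLNet's convention of placing the <cls> token at the end of the sequence):

import keras

# Take the feature at the last position (where <cls> sits) and add a softmax head.
last_token = keras.layers.Lambda(lambda x: x[:, -1])(model.output)
output = keras.layers.Dense(units=2, activation='softmax')(last_token)

train_model = keras.models.Model(inputs=model.inputs, outputs=output)
train_model.compile(
    optimizer=keras.optimizers.Adam(lr=3e-5),  # a typical fine-tuning learning rate (assumption)
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)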



"memory_len"が非0のパターン(「token_embeddings_with_memory.py」)のmodelのsummaryを表示してみる。

model.summary(Input)
Layer (type)            Output Shape          Param #  Connected to                     
======================= ===================== ======== ============================
Input-Token             (None, 3)             0                                    
(InputLayer)                                           
_______________________ _____________________ ________ ____________________________
Embed-Token             [(None, 3, 768),      24576000 Input-Token[0][0]           
(EmbeddingRet)           (32000, 768)]                 
_______________________ _____________________ ________ ____________________________
Masking                 (None, 3)             0        Input-Token[0][0]           
(CreateMask)                                           
_______________________ _____________________ ________ ____________________________
Embed-Token-Masked      (None, 3, 768)        0        Embed-Token[0][0]           
(RestoreMask)                                          Masking[0][0]               
_______________________ _____________________ ________ ____________________________
Input-Memory-Length     (None, 1)             0                                    
(InputLayer)                                           
_______________________ _____________________ ________ ____________________________
Memory-0                (None, None, 768)     7680     Embed-Token-Masked[0][0]    
(Memory)                                               Input-Memory-Length[0][0]   
_______________________ _____________________ ________ ____________________________
Input-Segment           (None, 3)             0                                    
(InputLayer)                                           
_______________________ _____________________ ________ ____________________________
Embed-Pos               (None, None, 768)     0        Embed-Token-Masked[0][0]    
(PositionalEmbedding)                                  Memory-0[0][0]              
_______________________ _____________________ ________ ____________________________
Permutation             [(None, 3, None),     0        Embed-Token-Masked[0][0]    
(PermutationMask)        (None, 3, None)]              Memory-0[0][0]              
_______________________ _____________________ ________ ____________________________

Permutation[0][0] : Content mask
Permutation[0][1] : Query mask
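
As a sanity check on the parameter counts: Embed-Token's 24,576,000 is vocabulary size × hidden units = 32000 × 768, and Memory-0's 7,680 is (memory_len + target_len) × units = (7 + 3) × 768, which appears to be the buffer that holds the past hidden states for batch_size=1.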

model.summary(Layer 1)
Layer (type)            Output Shape          Param #  Connected to                
======================= ===================== ======== ============================
Embed-Segment-1         [(None, 3, None, 2),  1536     Input-Segment[0][0]         
(Relative-               (2, 768)]                     Memory-0[0][0]              
 SegmentEmbeddings)
_______________________ _____________________ ________ ____________________________
Relative-Bias-1         [(768,), (768,)]      1536     Memory-0[0][0]              
(RelativeBias)                                         
_______________________ _____________________ ________ ____________________________
Segment-Bias-1          (768,)                768      Memory-0[0][0]              
(SegmentBias)                                          
_______________________ _____________________ ________ ____________________________
Attention-1             (None, 3, 768)        2949120  Embed-Token-Masked[0][0]    
(RelativePartial-                                      Embed-Token-Masked[0][0]    
 MultiHeadAttention)                                   Memory-0[0][0]              
                                                       Embed-Segment-1[0][0]       
                                                       Embed-Segment-1[0][1]       
                                                       Embed-Pos[0][0]             
                                                       Relative-Bias-1[0][0]       
                                                       Relative-Bias-1[0][1]       
                                                       Segment-Bias-1[0][0]        
                                                       Permutation[0][0]           
_______________________ _____________________ ________ ____________________________
Attention-Residual-1    (None, 3, 768)        0        Embed-Token-Masked[0][0]    
(Add)                                                  Attention-1[0][0]           
_______________________ _____________________ ________ ____________________________
Attention-Normal-1      (None, 3, 768)        1536     Attention-Residual-1[0][0]  
(LayerNormalization)                                   
_______________________ _____________________ ________ ____________________________
FeedForward-1           (None, 3, 768)        4722432  Attention-Normal-1[0][0]    
(FeedForward)                                          
_______________________ _____________________ ________ ____________________________
FeedForward-Residual-1  (None, 3, 768)        0        Attention-Normal-1[0][0]    
(Add)                                                  FeedForward-1[0][0]         
_______________________ _____________________ ________ ____________________________
FeedForward-Normal-1    (None, 3, 768)        1536     FeedForward-Residual-1[0][0]
(LayerNormalization)                                   
_______________________ _____________________ ________ ____________________________
Memory-1                (None, None, 768)     7680     FeedForward-Normal-1[0][0]  
(Memory)                                               Input-Memory-Length[0][0]   
_______________________ _____________________ ________ ____________________________
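
The per-layer parameter counts also match the Base configuration (units=768, feed-forward inner size=3072): Attention-1's 2,949,120 is 5 × 768 × 768, i.e. the query/key/value/output projections plus the relative-position projection (no bias terms), and FeedForward-1's 4,722,432 is 768 × 3072 + 3072 + 3072 × 768 + 768, i.e. two dense layers with their biases.
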
model.summary(Layer 12)
Layer (type)            Output Shape          Param #  Connected to                
======================= ===================== ======== ============================
Embed-Segment-12        [(None, 3, None, 2),  1536     Input-Segment[0][0]         
(Relative-               (2, 768)]                     Memory-11[0][0]              
 SegmentEmbeddings)
_______________________ _____________________ ________ ____________________________
Relative-Bias-12        [(768,), (768,)]      1536     Memory-11[0][0]              
(RelativeBias)                                         
_______________________ _____________________ ________ ____________________________
Segment-Bias-12         (768,)                768      Memory-11[0][0]              
(SegmentBias)                                          
_______________________ _____________________ ________ ____________________________
Attention-12            (None, 3, 768)        2949120  FeedForward-Normal-11[0][0]
(RelativePartial-                                      FeedForward-Normal-11[0][0]
 MultiHeadAttention)                                   Memory-11[0][0]              
                                                       Embed-Segment-12[0][0]       
                                                       Embed-Segment-12[0][1]       
                                                       Embed-Pos[0][0]             
                                                       Relative-Bias-12[0][0]       
                                                       Relative-Bias-12[0][1]       
                                                       Segment-Bias-12[0][0]        
                                                       Permutation[0][0]           
_______________________ _____________________ ________ ____________________________
Attention-Residual-12   (None, 3, 768)        0        FeedForward-Normal-11[0][0]
(Add)                                                  Attention-12[0][0]           
_______________________ _____________________ ________ ____________________________
Attention-Normal-12     (None, 3, 768)        1536     Attention-Residual-12[0][0]  
(LayerNormalization)                                   
_______________________ _____________________ ________ ____________________________
FeedForward-12          (None, 3, 768)        4722432  Attention-Normal-12[0][0]    
(FeedForward)                                          
_______________________ _____________________ ________ ____________________________
FeedForward-Residual-12 (None, 3, 768)        0        Attention-Normal-12[0][0]    
(Add)                                                  FeedForward-12[0][0]         
_______________________ _____________________ ________ ____________________________
FeedForward-Normal-12   (None, 3, 768)        1536     FeedForward-Residual-12[0][0]
(LayerNormalization)                                   

To be continued in the next post.