Kerasメモ(Memory network)
「bAbI」データセットを使ったMemory networkを試してみる。
論文
データ
Single Supporting Facts
tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_train.txt
1 Mary moved to the bathroom. 2 John went to the hallway. 3 Where is Mary? bathroom 1 ...
読み込み後のデータは3つ組のtupleで構成される。
tupleは「Input」、「Question」、「Answer」の3つを保持。
Input
train_stories[0][0] [u'Mary', u'moved', u'to', u'the', u'bathroom', u'.', u'John', u'went', u'to', u'the', u'hallway', u'.']
Question
train_stories[0][1] [u'Where', u'is', u'Mary', u'?']
Answer
train_stories[0][2] u'bathroom'
データ数はトレーニング用に「10,000」、検証用に「1,000」
全体の語彙数は「21」
vocab [u'.', u'?', u'Daniel', u'John', u'Mary', u'Sandra', u'Where', u'back', u'bathroom', u'bedroom', u'garden', u'hallway', u'is', u'journeyed', u'kitchen', u'moved', u'office', u'the', u'to', u'travelled', u'went']
これをindexで表現する。
word_idx = dict((c, i + 1) for i, c in enumerate(vocab)) word_idx {u'hallway': 12, u'bathroom': 9, u'garden': 11, u'journeyed': 14, u'office': 17, u'is': 13, u'bedroom': 10, u'moved': 16, u'back': 8, u'.': 1, u'to': 19, u'Daniel': 3, u'Sandra': 6, u'travelled': 20, u'went': 21, u'the': 18, u'John': 4, u'Where': 7, u'Mary': 5, u'?': 2, u'kitchen': 15}
各々の最大長は以下のとおり。
Input: 「68」単語
Qustion: 「4」単語
Answer: 「1」単語
教師データとして与えるAnswerはindex(「0」から「21」までの値)
ここで、index=0はdummy
# Reserve 0 for masking via pad_sequences vocab_size = len(vocab) + 1 # 22
モデル
モデルに関するメモ。
Input(Sentences)は、2つのEmbeddingレイヤをとおす。
# vocab_size = 22 # embedding_dim = 64 input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64)) # --> output: (samples, story_maxlen, embedding_dim) # query_maxlen = 4 input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen)) # --> output: (samples, story_maxlen, query_maxlen)
Questionは、1つのEmbeddingレイヤをとおす。
question_encoder.add(Embedding(input_dim=vocab_size, output_dim=64, input_length=query_maxlen)) # --> output: (samples, query_maxlen, embedding_dim)
変換後のInput(Sentences)とQuestionの内積をとって関連ベクトルを算出する。
match = dot([input_encoded_m, question_encoded], axes=(2, 2)) # --> output: (samples, story_maxlen, query_maxlen)