ichou1のブログ

主に音声認識、時々、データ分析のことを書く

Kerasメモ(Memory network)

「bAbI」データセットを使ったMemory networkを試してみる。

論文

arxiv.org

データ

Single Supporting Facts
tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_train.txt

1 Mary moved to the bathroom.
2 John went to the hallway.
3 Where is Mary? 	bathroom	1
...

読み込み後のデータは3つ組のtupleで構成される。
tupleは「Input」、「Question」、「Answer」の3つを保持。

Input
train_stories[0][0]
[u'Mary', u'moved', u'to', u'the', u'bathroom', u'.', u'John', u'went', u'to', u'the', u'hallway', u'.']
Question
train_stories[0][1]
[u'Where', u'is', u'Mary', u'?']
Answer
train_stories[0][2]
u'bathroom'

データ数はトレーニング用に「10,000」、検証用に「1,000」

全体の語彙数は「21」

vocab
[u'.', u'?', u'Daniel', u'John', u'Mary', u'Sandra', u'Where', u'back', u'bathroom', u'bedroom', u'garden', u'hallway', u'is', u'journeyed', u'kitchen', u'moved', u'office', u'the', u'to', u'travelled', u'went']

これをindexで表現する。

word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
word_idx
{u'hallway': 12, u'bathroom': 9, u'garden': 11, u'journeyed': 14, u'office': 17, u'is': 13, u'bedroom': 10, u'moved': 16, u'back': 8, u'.': 1, u'to': 19, u'Daniel': 3, u'Sandra': 6, u'travelled': 20, u'went': 21, u'the': 18, u'John': 4, u'Where': 7, u'Mary': 5, u'?': 2, u'kitchen': 15}

各々の最大長は以下のとおり。
Input: 「68」単語
Qustion: 「4」単語
Answer: 「1」単語

教師データとして与えるAnswerはindex(「0」から「21」までの値)
ここで、index=0はdummy

# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1    # 22
モデル

f:id:ichou1:20190323131501p:plain

モデルに関するメモ。
Input(Sentences)は、2つのEmbeddingレイヤをとおす。

# vocab_size = 22
# embedding_dim = 64
input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
# --> output: (samples, story_maxlen, embedding_dim)

# query_maxlen = 4
input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen))
# --> output: (samples, story_maxlen, query_maxlen)

Questionは、1つのEmbeddingレイヤをとおす。

question_encoder.add(Embedding(input_dim=vocab_size, output_dim=64, input_length=query_maxlen))
# --> output: (samples, query_maxlen, embedding_dim)

変換後のInput(Sentences)とQuestionの内積をとって関連ベクトルを算出する。

match = dot([input_encoded_m, question_encoded], axes=(2, 2))
# --> output: (samples, story_maxlen, query_maxlen)