ichou1のブログ

主に音声認識、時々、データ分析のことを書く

TensorFlowメモ(RNNその2)

LSTM(Long short-term memory)を試してみる。

SimpleRNNについては下記ご参照。
tensorflowメモ(RNNその1) - ichou1のブログ

kerasで実装する場合は、単純にRecurrentレイヤーを置き換えるだけ済む。

SimpleRNN
model.add(SimpleRNN(n_hidden,
                    input_shape=(maxlen, 1),
                    kernel_initializer='random_normal'))
LSTM
model.add(LSTM(n_hidden,
               input_shape=(maxlen, 1),
               kernel_initializer='random_normal'))

隠れ層のユニット数は「20」とする。

n_hidden = 20

(参考)model.summary()で表示
f:id:ichou1:20190105183810p:plain

学習で更新されるパラメータ数の内訳(右側はソースコード内の変数)

# lstm_1 Layer
# input gate
weight(input) :  1row * 20col = 20    --> self.kernel_i
weight(state) : 20row * 20col = 400   --> self.recurrent_kernel_i
bias          : 20                    --> self.bias_i

# forget gate
weight(input) :  1row * 20col = 20    --> self.kernel_f
weight(state) : 20row * 20col = 400   --> self.recurrent_kernel_f
bias          : 20                    --> self.bias_f

# input modulation gate(?)
weight(input) :  1row * 20col = 20    --> self.kernel_z
weight(state) : 20row * 20col = 400   --> self.recurrent_kernel_z
bias          : 20                    --> self.bias_z

# output gate
weight(input) :  1row * 20col = 20    --> self.kernel_o
weight(state) : 20row * 20col = 400   --> self.recurrent_kernel_o
bias          : 20                    --> self.bias_o

# dense_1 Layer
weight        : 20row * 1col
bias          : 1
レイヤ構成

f:id:ichou1:20190105191025p:plain
論文「LSTM: A Search Space Odyssey」より引用加工

左側が「SimpleRNN」の説明。
囲み部分が、"input"と"previous timestepの出力"に重みを掛けてバイアスを加算する処理にあたる。
f:id:ichou1:20190105195109p:plain

右側がLSTMの説明、ソースと一緒に見た方が分かりやすいと思われる(tensorflowバージョン1.5を想定)
tensorflow/recurrent.py at r1.5 · tensorflow/tensorflow · GitHub
1816行目以降が該当。

class LSTMCell(Layer):
...
  def call(self, inputs, states, training=None):

f:id:ichou1:20190105201129p:plain

上図に該当するソースコード(説明用に加工)

ソースコードの"h"が図中の"y"に該当、図の表記に合わせている

# tm1 means "t minus one" as in "previous timestep"

# "inputs" shape                  : (None, 1) 
# "self.kernel_*" shape           : (1, units)
# "y_tm1" shape                   : (None, units)
# "self.recurrent_kernel_*" shape : (units, units)

 
# (1)  "x_i","i" shape : (None, units) 
x_i = K.dot(inputs, self.kernel_i)
x_i = K.bias_add(x_i, self.bias_i)
i = self.recurrent_activation(x_i + K.dot(y_tm1, self.recurrent_kernel_i))

# (2)  "x_f","f" shape : (None, units)
x_f = K.dot(inputs, self.kernel_f)
x_f = K.bias_add(x_f, self.bias_f)
f = self.recurrent_activation(x_f + K.dot(y_tm1, self.recurrent_kernel_f))

# (3)  "x_z","z" shape : (None, units)
x_z = K.dot(inputs, self.kernel_z)
x_z = K.bias_add(x_z, self.bias_z)
z = self.activation(x_z + K.dot(y_tm1, self.recurrent_kernel_z))

# (4)  "x_o","o" shape : (None, units)
x_o = K.dot(inputs, self.kernel_o)
x_o = K.bias_add(x_o, self.bias_o)
o = self.recurrent_activation(x_o + K.dot(y_tm1, self.recurrent_kernel_o))


f:id:ichou1:20190105203002p:plain

上図に該当するソースコード(説明用に加工)
#  "c" shape : (None, units)
c = (f * c_tm1) + (i * z)


f:id:ichou1:20190105203450p:plain

上図に該当するソースコード(説明用に加工)
#  "y" shape : (None, units)
y = o * self.activation(c)