BERTメモ(structural probes)その2
前回の続き。
probe parametersを生成する「run_experiment.py」の動きを追ってみる。
今回、パラメータは「Depth」を指定する。
python structural-probes/run_experiment.py example/config/pad_en_ewt-ud-sample.yaml
pad_en_ewt-ud-sample.yaml(Depth)
probe: task_name: parse-depth
outputとなるprobe parameterを決める設定。
model: hidden_dim: 1024 probe: maximum_rank: 32
出力となるパラメータは以下のとおりとなる。
torch.save(probe.state_dict(), self.params_path) # (probe.state_dict())['proj'].shape # torch.Size([1024, 32])
使用しているdataset
GitHub - UniversalDependencies/UD_English-EWT: English data
example/data/en_ewt-ud-sample/en_ewt-ud-train.conllu example/data/en_ewt-ud-sample/en_ewt-ud-dev.conllu example/data/en_ewt-ud-sample/en_ewt-ud-test.conllu
example/data/en_ewt-ud-sample/en_ewt-ud-train.conllu
# newdoc id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000 # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0001 # text = Al-Zaman : American forces killed Shaikh Abdullah al-Ani, the preacher at the mosque in the town of Qaim, near the Syrian border. 1 Al Al PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 - - PUNCT HYPH _ 1 punct 1:punct SpaceAfter=No 3 Zaman Zaman PROPN NNP Number=Sing 1 flat 1:flat _ 4 : : PUNCT : _ 1 punct 1:punct _ 5 American american ADJ JJ Degree=Pos 6 amod 6:amod _ 6 forces force NOUN NNS Number=Plur 7 nsubj 7:nsubj _ 7 killed kill VERB VBD Mood=Ind|.. 1 parataxis 1:parataxis _ ...
Depthの場合、7列目の「HEAD」を使う。
HEAD: Head of the current word, which is either a value of ID or zero (0).
深さは、root(ここでは[Al])までのnode数で、[American]の場合、「3」になる。
[American] -> [forces] -> [killed] -> [Al]
(参考)日本語版のデータセット
Universal Dependencies
GitHub - UniversalDependencies/UD_Japanese-GSD: Japanese data from the Google UDT 2.0.
ja_gsd-ud-train.conllu
# sent_id = train-s1 # text = ホッケーにはデンジャラスプレーの反則があるので、膝より上にボールを浮かすことは基本的に反則になるが、その例外の一つがこのスクープである。 1 ホッケー ホッケー NOUN NN _ 8 iobj _ SpaceAfter=No 2 に に ADP PS _ 1 case _ SpaceAfter=No 3 は は ADP PK _ 1 case _ SpaceAfter=No 4 デ_プ_ デ_プ_ PROPN NNP _ 6 nmod _ SpaceAfter=No 5 の の ADP PN _ 4 case _ SpaceAfter=No 6 反則 反則 NOUN NN _ 8 nsubj _ SpaceAfter=No 7 が が ADP PS _ 6 case _ SpaceAfter=No 8 ある ある VERB VV _ 24 advcl _ SpaceAfter=No ... 24 なる なる VERB VV _ 34 advcl _ SpaceAfter=No ... 34 スクープ スクープ NOUN NN _ 0 root _ SpaceAfter=No ...
[ホッケー] -> [ある] -> [なる] -> [スクープ]
各senteceをencodeした結果は以下に格納されている。
example/data/en_ewt-ud-sample/en_ewt-ud-train.elmo-layers.hdf5 example/data/en_ewt-ud-sample/en_ewt-ud-dev.elmo-layers.hdf5 example/data/en_ewt-ud-sample/en_ewt-ud-test.elmo-layers.hdf5
structural-probes/regimen.py
# Diskからencode結果を読み込む場合は、そのまま使う word_representations = model(observation_batch) # observation_batch.shape # --> torch.Size([1, seq_len, 1024]) # word_representations.shape # -->torch.Size([1, seq_len, 1024])
probe parameterをトレーニングする。
structural-probes/probe.py
encode結果にprobe parameterを掛けて、
# Computes all n depths after projection for each sentence in a batch. transformed = torch.matmul(word_representations, self.proj) # self.proj.shape # -->torch.Size([1024, 32]) # transformed.shape # -->torch.Size([1, seq_len, 32])
各tokenごとに、要素の二乗和を計算
# Computes (Bh_i)^T(Bh_i) for all i norms = torch.bmm(transformed.view(seq_len, 1, rank), # (seq_len, 1, 32) transformed.view(seq_len, rank, 1)) # (seq_len, 32, 1) # --> torch.Size([seq_len, 1, 1]) # A tensor of depths of shape (batch_size, max_seq_len) norms = norms.view(1, seq_len)
教師データとして与えたDepthとの損失を最小化するようにトレーニングする。