ichou1のブログ

主に音声認識、時々、データ分析のことを書く

BERTメモ(structural probes)その2

前回の続き。

probe parametersを生成する「run_experiment.py」の動きを追ってみる。
今回、パラメータは「Depth」を指定する。

python structural-probes/run_experiment.py example/config/pad_en_ewt-ud-sample.yaml
pad_en_ewt-ud-sample.yaml(Depth)
probe:
  task_name: parse-depth

outputとなるprobe parameterを決める設定。

model:
  hidden_dim: 1024
probe:
  maximum_rank: 32

出力となるパラメータは以下のとおりとなる。

torch.save(probe.state_dict(), self.params_path)
# (probe.state_dict())['proj'].shape
# torch.Size([1024, 32])

使用しているdataset
GitHub - UniversalDependencies/UD_English-EWT: English data

example/data/en_ewt-ud-sample/en_ewt-ud-train.conllu
example/data/en_ewt-ud-sample/en_ewt-ud-dev.conllu
example/data/en_ewt-ud-sample/en_ewt-ud-test.conllu
example/data/en_ewt-ud-sample/en_ewt-ud-train.conllu
# newdoc id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0001
# text = Al-Zaman : American forces killed Shaikh Abdullah al-Ani, the preacher at the mosque in the town of Qaim, near the Syrian border.
1  Al       Al       PROPN  NNP  Number=Sing  0  root      0:root       SpaceAfter=No
2  -        -        PUNCT  HYPH _            1  punct     1:punct      SpaceAfter=No
3  Zaman    Zaman    PROPN  NNP  Number=Sing  1  flat      1:flat       _
4  :        :        PUNCT  :    _            1  punct     1:punct      _
5  American american ADJ    JJ   Degree=Pos   6  amod      6:amod       _
6  forces   force    NOUN   NNS  Number=Plur  7  nsubj     7:nsubj      _
7  killed   kill     VERB   VBD  Mood=Ind|..  1  parataxis 1:parataxis  _
...

CoNLL-U Format

Depthの場合、7列目の「HEAD」を使う。

HEAD: Head of the current word, which is either a value of ID or zero (0).

深さは、root(ここでは[Al])までのnode数で、[American]の場合、「3」になる。

[American] -> [forces] -> [killed] -> [Al]
ja_gsd-ud-train.conllu
# sent_id = train-s1
# text = ホッケーにはデンジャラスプレーの反則があるので、膝より上にボールを浮かすことは基本的に反則になるが、その例外の一つがこのスクープである。
1  ホッケー   ホッケー   NOUN   NN  _  8   iobj   _  SpaceAfter=No
2  に        に        ADP    PS  _  1   case   _  SpaceAfter=No
3  は        は        ADP    PK  _  1   case   _  SpaceAfter=No
4  デ_プ_    デ_プ_     PROPN  NNP _  6   nmod   _  SpaceAfter=No
5  の        の        ADP    PN  _  4   case   _  SpaceAfter=No
6  反則      反則       NOUN   NN  _  8   nsubj  _  SpaceAfter=No
7  が        が        ADP    PS  _  6   case   _  SpaceAfter=No
8  ある      ある       VERB   VV  _  24  advcl  _  SpaceAfter=No
...
24 なる      なる      VERB   VV  _  34  advcl  _  SpaceAfter=No
...
34 スクープ  スクープ    NOUN   NN  _  0   root   _  SpaceAfter=No
...
[ホッケー] -> [ある] -> [なる] -> [スクープ]



各senteceをencodeした結果は以下に格納されている。

example/data/en_ewt-ud-sample/en_ewt-ud-train.elmo-layers.hdf5
example/data/en_ewt-ud-sample/en_ewt-ud-dev.elmo-layers.hdf5
example/data/en_ewt-ud-sample/en_ewt-ud-test.elmo-layers.hdf5
structural-probes/regimen.py
# Diskからencode結果を読み込む場合は、そのまま使う
word_representations = model(observation_batch)
# observation_batch.shape
#  --> torch.Size([1, seq_len, 1024])
# word_representations.shape
#  -->torch.Size([1, seq_len, 1024])


probe parameterをトレーニングする。

structural-probes/probe.py

encode結果にprobe parameterを掛けて、

# Computes all n depths after projection for each sentence in a batch.
transformed = torch.matmul(word_representations, self.proj)
# self.proj.shape
#  -->torch.Size([1024, 32])
# transformed.shape
#  -->torch.Size([1, seq_len, 32])

各tokenごとに、要素の二乗和を計算

# Computes (Bh_i)^T(Bh_i) for all i
norms = torch.bmm(transformed.view(seq_len, 1, rank),  # (seq_len, 1, 32)
                  transformed.view(seq_len, rank, 1))  # (seq_len, 32, 1)

# --> torch.Size([seq_len, 1, 1])

# A tensor of depths of shape (batch_size, max_seq_len)
norms = norms.view(1, seq_len)

教師データとして与えたDepthとの損失を最小化するようにトレーニングする。