yoheikikuta / paper-reading

Notes about papers I read (in Japanese)
156 stars 4 forks source link

[2017] Attention Is All You Need #7

Open yoheikikuta opened 6 years ago

yoheikikuta commented 6 years ago

論文リンク

https://arxiv.org/abs/1706.03762

公開日(yyyy/mm/dd)

2017/06/12

概要

Sequential なデータを扱うときに、recursive や convolutional な構造を使わずに、attention (と positional encoding) のみを使うことで学習時のコストが低くかつ高性能なモデルを構築するという話。 モデルは以下のように図示される。

yoheikikuta commented 6 years ago

まずはモチベーションの話。

recurrent model は前の time step の情報を使って次の step の計算をするので、本質的に並列計算ができないという問題がある。factorization trick という、重み行列を2つの小さい重み行列の積にする業で効率的な計算をするという研究も出ているが、前の time step が必要という原理的な問題は解決してない(factorization trick はちゃんと理解してないのでちょっと想像が入っている)。

attention は recurrent network の文脈において、前の time step の情報のどこに着目するかを取り入れるためのものであったが、この論文では recurrent 構造を取り除いて attention だけで input と output の間の依存性を表現しようというもの。

yoheikikuta commented 6 years ago

まずは embedding と positional encoding の意味。

embedding は何かしらの方法で学習した普通の embedding で 512 次元のベクトルを返す。

positional encoding は recurrent 構造がないために位置情報をモデルに取り込むためのアイデアである。

i は埋め込む次元に対応するもので、i=0,1,...,255 と取ることで 512 次元の出力が得られ、先ほどの embedding と足すことができる。としたいところだが、論文で d_model = 512 と書いてあってちょっと齟齬が生じる。まあこれは len(d_model)=512 で d_model の要素は 0 から始まると考えるのが自然でしょう。

この正弦波は (position) * (波数) と考えれば波数が 1/10000^(2i/d_model) となっていて、i がパラメタであることを考えれば異なる波数の正弦波をなし、与えられた pos (位置) に対してそれぞれが異なる値を返すので位置に類する情報として使えるだろうということだと思われる。

pos に関しては普通に token の順番を整数で与えるということかな?それなら最大で 10000*2π の波長を持っているので、そのサイズ分くらいまでなら位置として使える情報が fixed dimension で扱えるということでしょう。

yoheikikuta commented 6 years ago

その次が主題となる self-attention (multi-head) の部分。

まず、attention をいくぶん抽象的に書いてみることにする。

イメージをしやすくするため具体的なものから考える。 encoder-decoder を考え、decoder における attention を考える。decoder での hidden state $ st $ を求める際に使う attention は $ f(s{t-1}, c_t) $ などと書け、この $ c_t $ は $ \sum a_i h_i $ ( $ h_i $ は encoder の hidden で、$ a_i $ はよく softmax 的に算出される重み) というのが standard な attention だろう。

ここで、自身の一個前の time step の情報を query と呼び、encoder 側の情報を {key, value} で表現する。というのも、こうすると attention は Attention (Q,K,V) = softmax (Q K^T) V と書くことができるためである。どの部分に注目するかというのを query と key 情報の内積の softmax で計算する、という分解である。

yoheikikuta commented 6 years ago

これを更に発展させて multi-head というものにする。

やりたいことは、異なる位置の異なる次元の表現を合わせて使うということである。 具体的には、query と key は $ d_k $ 次元、value は $ d_v $ 次元という異なる次元の表現を採用し、かつパラレルにいくつかを同時に走らせて最後に concat するということをする。式を見た方が早い。

パラレルにする、というのは CNN の channel 数を増やすような意味で表現力を向上させる、ということだろう。結局 $ d_k $ = $ d_v $ と置いてるので最初っから同じにしとけばいいのでは、という気がしないでもないが、一般には拡張できるということだろう。

yoheikikuta commented 6 years ago

ここではいくぶん抽象的な書き方をして、しかも最初のイメージは encoder-decoder から始めていたが、モデルではこの multi-head は encoder-decoder の繋ぎこみ部分だけでなく、encoder 内部、decoder 内部のみでも使われる(self attention)。これは何を意味しているのだろうか?

ここで計算量とかの概念が出てくる。

しかしここの complexity per layer がよく分からない。なんで self attention が $ O(n^2 d) $ なの? sequential op. と maximum per length が $ O(1) $ なのはいいとして、ここを理解すればいいのかはもう少し考える必要がありそう。

yoheikikuta commented 6 years ago

普通に考えるならばここの computational complexity は Softmax( Q K^T ) V だよなぁ。 そして self attention の場合は Nd matrix になると思ってるので、O( N d N N * d ) とかになるのでは?

attention の部分だけ、ということで Q K^T で O( N d N ) ということだったりするのかな? ただそう考えても Recurrent は Q が 1d になると思ってるので O( 1 d * N ) なのでは?

うーむ、どこか間違ってるんだろうけどどこなんだろうか。

yoheikikuta commented 6 years ago

computational complexity の正確な理解はできてないが、少なくとも self-attention が recurrent と比べて layer 毎に O(N) 倍の factor が掛かって、 sequential は O(1) - O(N) という対応なのは分かる。

実験の結果を見ておく。WMT の English-German と English-French のタスク。

低い学習コストで高い性能を達成している。 sequence を如何に効率良く扱うか、というのはここのところの流行りでもあるので、この論文のアイデアが理解できて良かった。 computational complexity に関しては詳しい人がいたら議論してみよう。

yoheikikuta commented 6 years ago

勉強会で発表をしたので理解が一新された。 その時のメモ: https://github.com/yoheikikuta/annotated-transformer/blob/master/paper-memo.pdf

yoheikikuta commented 6 years ago

Positional Encoding

理解は間違ってなかったが実際の実装まで含めてちゃんと理解した。(nbatch, sentence) という入力を考える。ここで sentence はある決まった長さの配列(何tokenまでの文章を扱えるか)で、どの token に対応するかという integer が格納されている。 この配列は要素の index が何番目の token かという位置を表しているが、その位置に対して d_model 次元のベクトルを assign するというのが Positional Encoding がやることである。数式で言えば pos を fix して i の添字がこのベクトルの成分の添字に対応するということである。 なので、これは純粋に位置だけを考慮するもので入力の内容には全く関与しない。

これを embed した入力に足すということで位置の情報を入れ込むということになっている。足すんじゃなくて concat の方が純粋に位置の情報を保持できるかとは思うが、この辺はきっと実験をして足すほうがいいという感じなんだろう。

yoheikikuta commented 6 years ago

Computational Complexity

まず、元々の Encoder-Decoder で Decoder における time step t での attention を $ \sumi a{t,i} h_i = Σ_i softmax( score(s_t, h_i) ) h_i $ として以下と対応させる。

この大文字は行列として、 attention を $ softmax (Q K^T) V $ という行列演算で表現する。 attention のコストを知りたいのでその重みとなる $ Q K^T $ の計算コストを考えてみるという話。 また、簡単な場合としてそれぞれの特徴量の次元は全て d で一定としてかつ {i,t} は n まで走るとする。

K は Encoder 側の情報で、shape は (n, d) 。 Decoder を考えるときには Encoder に関する演算は全て終わっているので、各 time step をまとめて扱える。 Q はある time step では (1,d) のベクトルなので $ Q K^T $ の計算コストが $ O(n*d^2) $ となる。 Decoder の time step は全部で n 個あって、一つ一つ順番にやっていく必要があるため、sequential に O(n) 回の同様の処理が走っていく。

self-attention の話にすると、まず key と value も自分自身から来ることになる。そして recurrent 構造をなくして系列を全て一つの配列にまとめると、 $ Q: (n, d), K: (n, d) $ になり、$ Q K^T $ は $ O(n^2 d) $ という計算になる。 この場合は Decoder の time step は全部まとめて扱えるので、sequential な処理はない(学習時は、の話。予測時は順次予測をしていくので sequential になる。ただ計算コストが重要なのは学習時なのでそれは問題ない)。

yoheikikuta commented 6 years ago

まだちょっと分かってないところ


気持ちとしては positional encoding と同じオーダーだと positional encoding が強く効きすぎてしまうので、その調整として入れているというものになっていそう。 ただし例えば BERT original 実装ではこの factor は含まれていないので、BERT のように positional encoding も学習するパターンでは学習に任せてこういう小細工は入れないというものだと理解。