yoheikikuta / paper-reading

Notes about papers I read (in Japanese)
156 stars 4 forks source link

[2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale [paper-reading] #56

Open yoheikikuta opened 3 years ago

yoheikikuta commented 3 years ago

論文リンク

https://arxiv.org/abs/2010.11929

公開日(yyyy/mm/dd)

2020/10/22

概要

画像分類を (convolution を使わず) transformer 型モデルで実施した。 Visual Transformer (ViT) というモデル名。 入力は計算量を抑えるために pixel 毎に全て attention を張るのではなく、patch に分け、各 patch を flatten した後に線形変換をしてある次元のベクトルにしたものを全 patch 文並べて embedding を作っている。当然 learnable positonal encoding も足してる。 事前学習は class token の出力に MLP を生やして教師あり学習として実施している(教師なしの方法も提案しているが簡素すぎるからか性能は低い)。 CNN based な過去手法よりも計算コストがだいぶ低い(1/4 TPU days くらいになる)が、ImageNet くらいのサイズのデータサイズだと劣後し、10倍, 100倍としていくことで同等以上の性能を発揮するようになる。

GitHub repository は https://github.com/google-research/vision_transformer

yoheikikuta commented 3 years ago

vision transformer の存在を知ってたけどちゃんと読んでいなかった。 DALL·E とかも出て vision への transformer 型モデルの適用をどうやっているのか真面目に理解しようと思って読むことにした。 最近 scalability が流行っててその辺にも言及がありそうだったのもモチベーションの一つ。

yoheikikuta commented 3 years ago

この研究のモチベーション自体は全然面白くなくて、Transformer 型のモデルが NLP でめっちゃ優秀な性能を発揮してるので出来るだけそれをそのまま vision の方に持ってこようというもの。

実験結果とかを考慮しない場合、よくこれが研究のモチベーションになるよなという感じがする(めちゃくちゃ乱暴に言えば他分野で流行ってるモデルをこっちの分野にも適用してみます、という話なので)。

もう少し話を押し進めると、Transformer 型モデルには「なぜか」べき乗則(test loss がモデルサイズもしくはデータサイズのべき分布になっている)が成り立つので、それが vision においても成立するかということを確かめたいというものになっている。

Transformer 型モデルのべき乗則の背後に何があるのか、それを理解することで(現状理解されているべき乗則を超えた)有効なモデルが作れるのか、にはかなり興味がある。これは個人的な興味で、この論文ではそれについて何かを示唆しているわけではない。

yoheikikuta commented 3 years ago

あとちょっと分かってないのは Transformer 型モデルの計算効率性というのも推しているところ。 Transformer の出自を考えたときに、RNN 型のモデルと比べて並列計算との親和性が高いので(GPU や TPU を用いた並列計算が主流になった現在において)計算効率が良いというのは分かる。

vision は元々 convolution が主流でこれは例えばチャンネル毎に並列計算可能なので、vision の文脈でいうところの計算効率性というのはどういうものなのだろうか。 ResNet とかだと層の数が多くてこれは前の層を計算しない限り後の層が計算できないので、その意味で並列性が犠牲になっているけど、Transformer 型のモデルなら(層ごとのパラメタ数は多いけど)層の数は多くないものの方が計算効率はいいということかな?

yoheikikuta commented 3 years ago

モデルの話をする。 基本的な構造は Transformer の encoder 部分を使うという感じなので難しいところは特にないが、下記二点だけは扱いが異なる。

モデルの概念図は以下の通り。

Transformer Encoder 部分はどこに Layer Norm を入れるとかで選択肢があるけどここに書いてある情報だけで十分なので割愛。

モデルの input に関して。まず元の画像は $ x ∈ R^{H \times W \times C} $ とする。 ここは色々選択肢がありうる。NLP では token 毎に 768 dim とかの分散表現に embed していた。画像を同じ仕組みに持ってくる場合に、どういう単位を token とするのがいいだろうか?pixel 毎に embedding することも考えられるが、1 pixel を高次元に embedding してもあまり意味がないだろうし、token 数も画像サイズだけ出てくるのでデカすぎて却下だろう。画像全体をまとめて embedding してしまったら self-attention もクソもないのでこれも却下。 ということで元画像をある程度のサイズの patch に分ける: $ x ∈ R^{N \times P \times P \times C} $ P はパッチのサイズで 16 などと選び、これが決まれば $ N = HW / P^2 $ で N は定まり、例えば (224, 224, 3) の画像の時は N = 196 となる。カラー情報は当然そのパッチに付随するものなので、やりたいのは $ P \times P \times C $ を一つの token としてそれを N 個並べるということだ。 この $ P \times P \times C $ に learnable な行列を掛けて hidden dim D (=768) に embed するということで入力 token 列を形成する。 分類のために頭に class token として同じ D 次元の embedding を加え、あとはやはり(simple な)learnable な poisiton embedding を足せば、入力 embedding は完成となる。ここまでできればモデルの話は Transformer とか BERT が分かってれば完全に理解したと言ってよいだろう。

この patch にして embedding する、というのはなかなか良い塩梅なのではないかと感じる。 画像の局所的な関係性は patch で捉えて、大域的な関係性は self-attention で捉えるという形になっている。 CNN と本質的に異なることは明示的な並進不変性などが取り込まれていないことだ。convolution の利点の一つはそこなわけで、これさえも人間による inductive な bias として取り込むのではなく、embedding でデータからよしなに学習してもらえば十分だろう、という考えだろう。

これはなかなか過激だし、面白いし、個人的にはちょっと悲しい。 人間の持つ知識を convlution という形で取り込んでいたわけだが、そういった明示的な規則ではなく多様なデータから自然と学べば十分だろうという態度の表明であるわけなので。

ちなみにこの論文では convolution を使って feature map を作ってからそれに対して embedding をするという hybrid モデルも試している。 具体的には入力に対して ResNet 構造を適用(これらのパラメタも学習対象)して、得られた feature map を embedding して positon embedding を足して同じように Transfomer Encoder に接続するというモデルである。 これはパラメタ数も増えるし計算効率性も悪いし、モデルとしてはイマイチな感じだが、convolution が本質的に重要であるか、という観点で試しているのだと理解している。

ちょっと長くなったので pre-training に関しては別で書くか。

yoheikikuta commented 3 years ago

自分が Vision Transformer の話を最初に聞いた時に思ったのは「へ〜 pre-training はどうやってるんだろう?」というものだった。 というのも、BERT で有効性が強烈に示された masked language model のような学習方法を良い感じに vision で実施する方法が思い浮かばなかったからだ。

NLP においては分布仮説を信じて masked language model で学習するというのは bet するのに十分有望な戦略だと思うけど、画像の場合にもそれってやれるんだっけ?とか思っていた。 画像の一部が欠損している状態でそれを類推するとかいうのはそこまで突拍子もない感じはしないけど、画像の patch 情報を全部再現するように pixel-wise で二乗誤差使って学習するのはいかにもダメそうだよな〜とか。

結論から言うと、Vision Transformer では unsupervised pre-training はしてなくて、画像分類ラベルがついている画像データで supervised training をしていた!マジかよ! ということでさっきのモデルの図には class token の出力に MLP head がついていて分類問題を解くようになっていたというわけだった。pre-training は Imagenet 1k, Imagenet 21k, JFT-300M とかで 100万 ~ 3億 枚の画像データを使っている。データセットによってクラス数が違うので MLP layer はそれに応じて作って pre-training して、downstream タスクを解く時はこの MLP 部分をすげ替えて fine-tuning するという学習方法になっている。

これは期待外れ(自分が勝手に期待してただけだが)の結果だが、この論文では unsupervised な方法も試した上で最終的に supervised の方がいいので supervised pre-training をしている。 どのようにやったかというと以下。

ここはもうちょっと工夫しがいがありそうだけど、なんか論文ではこの方向性にそんなに熱心ではない感じがする。 画像データは分類用のデータセットが大量にあるから supervised でデータ量確保できるということなのだろうか?ここを unsupervised でうまくできるようにするとべき乗則を活かして大規模学習を進める、ということがやりやすいので潤沢な計算資源を持つ人々はやりたくなりそうなもんだけど。


ちなみに PyTorch 実装 https://github.com/lucidrains/vit-pytorch では Bootstrap Your Own Latent (BYOL) https://arxiv.org/abs/2006.07733 という手法で unsupervised 学習ができるようになってる。 チラ見したら、ネットワークを複製して、一方はもう一方のパラメタの exponential moving average とかにして出力が一致するように学習(augmentation とかを使いつつ)していくというものらしい。これだけ聞くとそんなにうまいこといくの?という感じがするが、本題とは外れるので機会があったらまた読んでみることにする。

yoheikikuta commented 3 years ago

実験結果を見る。

その前に JFT-300M という画像データセットを知らないのでちょっと調べておく。 https://arxiv.org/abs/1707.02968 の論文で使われているもので、18291 クラスの 300 million 画像データセットのこと。そういえばこの論文が出たときに 3 億枚の画像云々という話を聞いたな。 これどこでダウンロードできるのかと思ったら、Google 内部でのみ使っているデータなのね... はぁ... データサイズが重要という話なのにその再現性がない(データ公開されてても計算量的にほとんど誰もできないけど)というのは悲しいっすな...

このデータで pre-training して、各種 downstream タスクを解いた結果が以下。 例えば ViT-H/14 の 14 は patch サイズを意味している(ので例えばこれが大きくなると、入力画像サイズが一定の場合に token 数が減るということになる)。 BiT (Big Transfer) や Nosy Student は先行研究で性能が高いモデル。

精度が驚くほど向上したというものではなく、十分デカいモデルを使えば先行研究で性能の高いモデルと同等以上の性能を発揮できる、というものになっている。

注目すべきは学習が終わるまでに必要な計算時間で、ViT-H/14 で先行研究の 1/4 以下になっている(TPUv3-core-days で 2.5k なので普通の人には手が出せるレベルではないが)。これはバッチサイズとか epoch 数とかを合わせて学習しているので、学習の条件としてはモデル以外は同等になっているはず(他にも当然 weight decay とかも影響するのでその辺もできるだけ同じ条件にしてる)。 それで 4 倍速く学習できるにようになってるので確かに効率的に学習できていると思われるが、その違いがどこから来ているのかは結構非自明(というけ明確には分かってない)。

まず、パラメタ数は ViT-Huge で 6.32 億で、BiT は 9.3 億、EfficientNet-L2 は 4.8 億。 EfficientNet-L2 はパラメタ数は少ないけど学習は遅いのは Noisy Student 使ってて teacher と student の 2 つのネットワークを使うことになるから遅いということかな(なので inference が遅いということは特にないはず)。 BiT はパラメタ数多くて演算回数が多いので遅い?のかと思ったけど、学習時に必要になるトータルの演算数はこの論文からは読み取りづらい... 例えば appendix の Table 6 に以下の表があるが、これは scaling の実験をしたもので、上で載せた結果と同じものがない... BiT つまり ResNet152x4 は表にはないが、とはいえ幅を 2 倍にしても演算数が 4 倍程度という感じなので、ViT-H/14 と遜色なさそう。

そしたら縦にたくさん積んでるので層毎の計算の依存が多いので遅いのかな〜と思ったけど、(学習じゃないけど)inference は同じくらいの速度だった。以下の図の左から ViT-H/14 と R152x4 は同等程度。右からはバッチサイズに関しては ViT-H/14 の方が多く持てそうだけど、学習時にはバッチサイズ揃えて計算してるはずなのでこれは効いてないはず。

ということで自分にはなぜ ViT の学習が他と比べてこれだけ速いのかはイマイチわからなかった。 ここは結構重要だと思う(精度はほぼ同等でしかないので)んだけど、論文でもちゃんと書いてない(と思われる)のでう〜むという感じ。誰か詳しい人に教えてもらいたい。

yoheikikuta commented 3 years ago

データセットサイズに対する scaling について。 ViT では pre-training は supervised なので、ここでは ImageNet 1k, 21k, JFT-300M のそれぞれで pre-training した場合に downstream タスクの性能がどうなったかで検証している。

結果は以下の図の左。 BiT に関しては色々なモデルサイズで実施し、その上限と下限を示している。点にするとごちゃごちゃになっちゃうからっすな。 ImageNet 1k では明らかに BiT より低性能だしデカいモデルだから良いというわけでもないが、データセットを大きくするとモデルがデカい順に並びそして BiT と同等以上の性能を発揮するようになっている。

右の図も学習サンプル数を増やすことで ViT の大きいモデルは特に 30M-100M で大きく性能を伸ばしていることが分かる。

同じような scaling を横軸を学習時の総演算回数でプロットしたものが以下の結果。

これは hybrid (token を作る前に resnet 構造を入れて feature map を抽出するモデル) もプロットされているが、十分に学習する前には convolution を明示的に入れたものが性能に寄与しているが、十分に学習すると convolution を入れても効果は特にないという結果になっている。 これはなかなか impressive な結果。理屈としては十分データがあって学習すれば確かに convolution の構造を明示的に入れなくてもそれに相当するものを学習できてもよさそうというのは分かるが、それがいままさに示されたというのはやはり驚き。

この図は Table 6 のモデルに対するプロットで R152x4 がないというのはやっぱり納得がいかないが、結果としてはデカいが強いで scaling してそうというものになってる。

論文では ViT は性能が saturate してなさそうだからもっと scaling させたいねって言ってるけど、それは BiT も同じだと思うけどな。 色々実験してるけどなんか結果の見せ方が ViT を無駄に良く見せようとしてる感じがして気になる。

yoheikikuta commented 3 years ago

positional embedding に関しては

という結果。 2D 的に取り扱うとかいかにも有効そうだけど、そんなん気にせず 1D で token 作って並べて突っ込んどけばいいんや!ってことらしい。 これは自分の予想に反してたのでやや驚き。結論としてはシンプルなので分かりやすいけど。

yoheikikuta commented 3 years ago

他にも色々やってるけど、自分が一番興味があった部分はチェックできた(理解できてないところはあるけど)ので、だいたいこんなもんかな。

patch 化するところだけちゃんと把握できてれば Transformer 型モデルそのものに近い構造なので、これは確かに text と一緒に使いたくなってくるね。それを実現しておどくべき成果を発揮したのが DALL·E ということだと思うので、次はそこを理解したいね(論文は出てないけど)。