e4exp / paper_manager_abstract

0 stars 0 forks source link

A Dot Product Attention Free Transformer #665

Open e4exp opened 2 years ago

e4exp commented 2 years ago

概要: Dot Product Attention Free Transformer (DAFT)を紹介します。 これは、Transformer (Vaswani et al., 2017) 自己注目におけるクエリ・キーのドット積を排除する効率的な変形です。 その核となるアイデアは、クエリの各次元、キーと値に対して分解可能なアテンションマップを構築することです。 この構成により、アテンション・テンソルを明示的に計算したり保存したりする必要のない実装が可能になります。 DAFTレイヤーは、コンテキストサイズと特徴量の次元の両方に対して線形のメモリ計算量を持ち、大きな入力サイズとモデルサイズの両方に対応しています。 また,DAFT-convは,グローバルな接続性を維持しつつ,局所性と空間的な重みの共有を利用したモデルのバリエーションです。 ImageNet-1Kの分類、CIFAR10およびEnwik8という2つの自己回帰モデリングタスクで実験を行いました。 その結果,DAFTはすべてのベンチマークにおいて競争力のある性能を発揮し,同時に優れた効率性を実現していることがわかりました。

e4exp commented 2 years ago

1 はじめに

Transformers(Vaswani et al., 2017)に代表される自己注目メカニズムは、言語理解(Devlin et al., 2018; Radford et al.)やコンピュータビジョンアプリケーション(Chen et al.; Dosovitskiy et al., 2020; Touvron et al. 畳み込みニューラルネット(CNN)やリカレントニューラルネット(RNN)のような古典的なモデルアーキテクチャとは異なり、Transformerはシーケンス内のすべてのペアの要素間で直接相互作用が可能であり、長期的な依存性を捉えるのに特に強力です。 トランスフォーマーは、Multi-head Attention(MHA)の上に構築されています。 MHAは、入力に対してドットプロダクトアテンションをh回実行し、チャネル次元に沿って出力を連結します。 これにより、空間的な複雑さはO(hT^2 )となります(Tは入力シーケンスの長さです)。 このように、Transformerは、長いシーケンスと大きなモデルサイズに同時に対応することが課題となっています。 最近では、Transformerのスケーラビリティ問題の解決に向けた研究が数多く行われている(Childら、2019年、Kitaevら、2020年、Raeら、2020年、Wangら、2020b、Katharopoulosら、2020年、Tayら、2020a、Choromanskiら、2020年)。 ここでの共通のアイデアは、スパース性(Child et al., 2019; Zaheer et al., 2020)、ローカリティセンシティブハッシュ(Kitaev et al., 2020)、低ランク分解(Wang et al., 2020b)、カーネル近似(Katharopoulos et al., 2020; Choromanski et al., 2020)などに及ぶ技術を用いて、フルアテンション操作を近似することである。

本論文では、標準的なドットプロダクトアテンションを使用せず、また近似しない計算モジュールを提案する。 したがって、我々のモデルをDot product Attention Free Transformer (DAFT)と名付けます。 DAFTは2つの単純なアイデアを追求しています。

1)アテンションヘッドの計算量を加法的かつ分解可能にすることで減少させる 2)最小限のコストで入力次元と同じになるようにヘッドの数を増加させる。

我々の実装では、DAFTは、key&valueを要素単位で乗算し、(学習したペアワイズ位置バイアスを用いて)時間次元に沿って縮小し、最後にクエリを要素単位で乗算するように計算グラフを再編成します。 これにより、アテンション・マトリクスの明示的な計算が不要となり、空間的な複雑さはO(T d)となります。

DAFTに密接に関連する研究として、「線形注目」トランスフォーマー(LAT)がある(Katharopoulos et al., 2020; Choromanski et al., 2020; Peng et al., 2021; Bello, 2021)。 LATは、非線形性をexpからidentityに変更した、ドットプロダクトアテンションの特殊なケースと解釈できます。 この変更により、クエリと対話する前にquery&keyのコンテキストリダクションが起こるという、同様の計算の再編成が可能になります。 DAFTとLATには2つの重要な違いがあります。

1) DAFTは標準的なMHAと同じexp nonlinearityを維持していますが、LATはこれを取り除きます。経験的に、exp非線形性の除去はしばしば性能低下の原因となります。 2) LATはO(T dd' )という空間的な複雑さを持つ。ここでd'はクエリとキーの投影の次元である。このため、LATはモデルサイズが大きい場合には高価になります。一方、DAFTはO(T d)の空間的複雑さを持ち、大きな入力と大きなモデルの両方に友好的です。

表1に比較の概要を示します。

image

基本的なDAFTの定式化に加えて,位置バイアスの局所性と重みの共有を利用したDAFT-convを提案します。 DAFT-convは、TransformerとConvNetsの両方の利点を受け継いでいます。 DAFT-convはTransformerと同様に配線され、クエリ、キー、値の間のグローバルな接続性と乗法的な相互作用を維持しています。 また、ConvNetsのパラメータ効率、スパースな計算、並進等変則性の特性を享受しています。 DAFT-convは、我々がテストした全てのタスクにおいて、基本バージョンよりも優れた性能と効率性を提供することを示します。

画像分類、画像の自動回帰モデリング、文字レベルの言語モデリングのタスクでDAFTの実験を行いました。 その結果、DAFTは、標準的なTransformerやその他のバリエーションに匹敵する、あるいは凌駕する競争力のある性能と、優れた効率性を提供することが分かりました。 また、DAFTのいくつかの設計上の選択に対するアブレーションの研究を行い、Transformerとの互換性、カーネルサイズの選択に対するロバスト性、可変サイズの入力など、DAFTのユニークな特性について議論しました。

e4exp commented 2 years ago

2 背景

2.1 マルチヘッドアテンション

ここで、自己注意モードにおけるマルチヘッドアテンション(MHA)の操作を紹介する。 ここでは、入力シーケンスをX∈RT ×dとし、T、d、hはそれぞれシーケンス長、特徴次元、ヘッド数hである。 また,上付き文字のiはi番目の注意ヘッドを,下付き文字のtはt番目の位置を表している. MHAは、各ヘッドiに対して、次のように定義されるスケールドット・プロダクト・アテンションを行う。

image

ここで、Wi Q∈Rd×dk , Wi K∈Rd×dk , Wi V∈Rd×dvはヘッドiの線形変換であり、Y i t∈Rdvはt番目のクエリロケーションに対するi番目のアテンションヘッドの出力である。 dk, dvはそれぞれキーとバリューの次元である。 MHAはh個のアテンションヘッドの出力をチャネル次元に沿って連結し、特徴量次元hdvを得る。 特に言及しない限り、dk = dv、h = d dk と仮定します。 これは、クエリ、キー、値が各ヘッド内で同じ次元であり、出力次元が入力の次元と一致することを意味します。

2.2 線形注意式

式1のexp非線形性を除去することで、(Katharopoulos et al, 2020; Choromanski et al, 2020)で使用されている線形注意を生じさせることができる。 この特別な形のドットプロダクトアテンションは、次のように計算の順序を並べ替えることができます。

image

ここでは、QとKの次元がRT ×d'であることを許容し、d'はdとは無関係に選択できる(例えば、Performer (Choromanski et al, 2020)では、d' = O(d log d)とすることが推奨されている)。 また、QとKには、必要に応じて追加の線形投影と非線形性が含まれていることを想定している。

e4exp commented 2 years ago

3 METHODOLOGY

3.1 DOT PRODUCT ATTENTION FREE TRANSFORMER

ここでDot product Attention free transformer (DAFT)を定義するが、これはTransformerの他のアーキテクチャ面を変更する必要のないMHAのプラグイン置き換えである。 基本的なアイデアは、ドット積アテンションを、効率的な実装に相当する付加的アテンション(Bahdanau et al. 2014 まず、以下のように定義されたアテンションヘッドを考えることから始める。

image

ここで、w∈RT ×Tは、学習されたペアワイズポジションバイアスのセットである。 ここで、注目スコア(exp前)は、クエリ、キー、そして「静的注目」スコア(w)の加算合成になります。 このようにすることで、自動的にアテンションヘッドの数が特徴次元と等しくなるようにもしています。 この新しい形式は、1つのアテンションヘッドの表現力を、より多くのヘッドと交換するものだと解釈できます。 しかし、注意の正規化メカニズムのために、クエリtermは分子と分母から相殺されます。 そこで、若干の修正を加え、次のようにします。

image

ここで,σqはシグモイドの非線形性である. そうすることで、分母からQi tを取り除くことができ、LSTMのHochreiter & Schmidhuber (1997)における出力ゲートの役割に似ています。 これで、DAFTの完全な計算を次のように定義することができます。

image

ここで、は要素和積、Qt, Kt 0 , Vt 0 ∈Rd 。

image

効率的な実装。

残る問題は、DAFTをいかに効率的に実装するかということです。 トランスフォーマーで行われているように式4の注目度を明示的に計算することは、サイズT 2dのテンソルを格納する必要があるため、実現不可能であることは明らかです。 幸いなことに、DAFTの組成性により、線形注目の場合と同様の計算の再配置が可能です。 要するに、exp(x + y) = exp(x) exp(y)という単純な等式を利用して、計算を要素ごとの乗算/除算と行列テンソルの乗算に分解します。 さらに、wの素朴なパラメータ化は、パラメータ集約的である。 そこで我々は、wを次のように因数分解した形で採用する。

image

ここで,n は小さな埋め込み次元(例えば 128)である. この単純な因子分解は,パラメータ数を大幅に削減するだけでなく(2T n vs T 2 ),経験的にトレーニングとテストの両方でモデルの性能を向上させることができます. このバージョンをDAFT-fullと呼び,PyTorchスタイルの疑似コードをAlgorithm 1に示します。

image

3.2 DAFT-CONV

DAFT-fullでは,wは2つの位置の間の静的なバイアスに符号化されます。 したがって、MHAと同じように、hセットの位置バイアス{wi}i=1...hを学習し、それぞれが特徴次元のサブセットと相互作用する「マルチヘッド」バージョンを考えるのは簡単です。 しかし、wは自由パラメータとして学習されるため、hを大きくするとパラメータの効率化がすぐに課題となる。 我々はこのジレンマに対して、wiをローカルにすると同時に、ポジション間でパラメータを共有するというシンプルなソリューションを提案する。

より正確には、入力のネイティブ構造(1dまたは2d)を利用し、位置ペア(t, t' )が指定されたウィンドウサイズ2s + 1内にあれば、それらの相対位置に基づいてゼロ以外の値を割り当てる(ウィンドウサイズは奇数であると仮定する)。 1dの場合(2dの場合も同様に導き出せる)、これは次のようになります。

image

ここで,r^i∈R^{2s+1},i=1,2,...,hは,相対位置ベクトル1です。 ここで,DAFTのマルチヘッド,畳み込みバージョンであるDAFT-convを紹介します。 DAFTconv は,畳み込みプリミティブを用いて,次のように実装できます(再び 1d 畳み込みを例にします)。

image

ここで,hはヘッドの数,dw-conv(x, r)は,x∈RT ×dを入力信号,r∈R(2s+1)を畳み込みフィルタとし,「同じ」パディングで1次元に沿って畳み込む,深さ方向に分離可能な1次元畳み込みを表す。

DAFT-fullと比較して,DAFT-convはパラメータと計算効率の両方を実現しています。 さらに,DAFT-convは,多くの領域で有用であることが証明されているConvNetsと同様に,局所性と並進等変量の帰納的バイアスを享受します。 また、DAFT-convには特徴的な特性があります。 それは、ウィンドウサイズの選択に関わらず、任意の2つの位置間のグローバルな接続性が維持されることです。 DAFT-convは窓サイズが非常に小さくても強力な性能を発揮することを実験的に検証しています。

実装。

学習の際には、畳み込みフィルタrの再パラメータ化を採用し、各ヘッドiに対して、次のようにする

image

ここで,γ∈Rd,β∈Rdは学習可能なゲインとバイアスのパラメータで,いずれも0に初期化されています. これにより,学習が加速され,一貫して性能が向上することがわかりました. PyTorchスタイルの疑似コードをAlgorithm 2に示します。簡略化のためにh = dと仮定しています。

e4exp commented 2 years ago

5 実験

画像分類(項5.1)、画像の自己回帰モデリング(項5.2)、文字レベルの言語モデリング(項5.3)の3つのタスクについて実験を行いました。 すべての実験は、特定のタスクのためのベースラインTransformerアーキテクチャを入手し、注目モジュールをDAFTモジュールに置き換えるという、プラグ・アンド・プレイ方式で設計されています。 初期化や学習率のスケジューリングなどのハイパーパラメータも、Transformerの対応する部分から直接継承しています。 特に断りのない限り、すべての実験は8×V100 GPUマシンで行われました。

5.1 画像分類

まず、画像分類タスクに焦点を当てて、DAFTのエンコーダバージョンをテストします。 Vision Transformerアーキテクチャ(Dosovitskiy et al., 2020)を採用し、Imagent 1K分類データセットで実験を行いました。 DeiT(Touvron et al., 2020)の学習設定とハイパーパラメータ(バッチサイズ、データ増強、正則化、学習率スケジューリング)を採用しています。

簡単に説明すると ViTは画像を16×16の重ならないパッチに分割し、各パッチをトークンの埋め込みの等価性に共有の重みをつけて線形投影する。 学習されたクラストークンは、結果として得られた表現に付加され、長さ T = 1 + (H/16 )/ (W/16)のシーケンスになります。

線形分類ヘッドを最終層のクラストークンに付けて、最終出力を得る。 モデル構成の詳細は(Dosovitskiy et al., 2020)を参照してください。 全ての実験はImageNet-1Kデータセットを用いて、余分なデータを使わずに行われています。 このタスクではシーケンスサイズが比較的小さいので(224 × 224の入力サイズに対してT = 197)、まずDAFT-fullで実験する。 因数分解された位置バイアスの隠れ次元をn = 128とします。 また,DAFT-convを用いて実験を行います.

この設定では,位置エンベッディングとクラストークンの使用を削除し,最終層の出力の後にグローバルアベレージプーリングを適用して,分類線形層に供給します. これにより、モデルの設計が簡単になるだけでなく、DAFT-convを完全な畳み込み型にすることができました。 ベースラインとして、DeiT "small" (L=12, d=384, h=6, DeiT-Sと表記)の構成を採用しています。 また,MLP-Mixer (Tolstikhin et al., 2021)と比較します. ここでは,注目層を隠れ層サイズDS = 4TのMLPに置き換えます. これは、対応するMHA層のパラメータ数とほぼ一致します。 また、Linear Transformer (Katharopoulos et al., 2020)とPerformer (Choromanski et al., 2020)という2つの線形注意モデルと比較します。 どちらもQ, K次元はDeiT-Sと同じにし、非線形性はそれぞれの論文で推奨されている1+eluとreluを採用しています。 その結果を表2に示します。

まず、DAFT-fullは、ベースラインのTransformer DeiT-Sと同等の性能を達成していますが、メモリフットプリントが改善され、速度も同等であることがわかります。 DAFT-convは、パラメータ数が同等または少ないにもかかわらず、両構成のトップ1精度を大幅に向上させています。 また、DAFTはLinear TransformerとPerformerよりもトップ1精度で優れています。 また、経験的に、DAFTはTransformerよりも収束が速いことがわかっています。 実際、DAFT-convモデルを全エポックのうち23回だけ学習することで、DeiT-Sの性能に匹敵することができました。

視覚化。

図1に示すように、DAFT-fullとDAFT-convが学習した位置の偏り(正確にはexp(w)-1)を可視化することも試みました。 興味深い局所的で対称的なスパースパターンが現れることに注目してください。 付録では、位置バイアスを正則化することで、より多くのスパース性を得ることができることを示しています。 また、DAFT-convの極端なバージョンを示します。 ここでは、各ヘッドに1つの非ゼロのコンテキストポイントが割り当てられますが、それでも良い精度が保たれます。 これにより、畳み込みを効果的にインデキシングに変換することができます。

可変サイズの入力。

DAFT-convは完全な畳み込みであるため、学習時とは異なる入力サイズを扱うことができます。 我々は、DAFT-convモデル(表2の最後から2番目の行、クロップサイズ224でトレーニング)を、より大きなクロップサイズ384でテストしました。 その結果,精度が81.6に向上しました(オリジナルは80.8). このことから、DAFT-convはVisionタスクでよく見られる事前トレーニングの微調整ワークフローに適していると言えます。

Transformersとの互換性。

DAFTはMHAを直接近似するように設計されたものではありませんが、どちらのモデルも値ベクトルが学習された非負の重み付けで集約されるという点でかなりの類似性があります。 我々は、あるモデルで学習された表現が別のモデルに移行できるという仮説を立てた。 この仮説を検証するために,事前に学習したクロップサイズ384の「DeiT base」モデルを入手しました。 次に,DAFT-convの重みをDeiTモデルの重みで初期化して,DAFT-convを学習します。 バッチサイズを64とし,100エポックの学習を行いました. コントロールとして,ランダムに初期化されたDAFT-convも同じ回数のエポックで学習させました。 その結果を表3に示します。

興味深いことに、DAFT-convを微調整したバージョンは、ランダムに初期化したバージョンよりも有意に高い精度を達成していることがわかります。 また、結果として得られたモデルは、オリジナルのDeiTモデルよりも精度、速度、メモリ効率が高くなっています。 グローバルな接続性 DAFT-convは、ローカルカーネルサイズに関わらず、グローバルな接続性を維持します。 これは、スパースやローカルアテンションの作品とは異なる点です。 この設計の利点を見るために、我々はDAFT-convの縮退変形を学習しました。 ここでは、ローカルウィンドウの外側のwt,t0に-∞の値を割り当てるように式7を修正します(指数化の後に重みをゼロにします)。

次に、カーネルサイズを変化させることで、一連のDAFTconvモデルを学習し、その結果を図2に示します。 DAFT-convの性能は、カーネルサイズの選択に対して非常にロバストであることがわかります。 グローバルな接続性がない場合、小さなカーネルサイズは深刻なパフォーマンスの低下につながります。

5.2 画像自己回帰モデル化

次の実験では、負の対数尤度(NLL)を最小化することによる画像自己回帰モデル化の問題を検討する。 (Parmar et al., 2018)と同様に、RGB画像を長さH×W×3のシーケンスとして表現し、H、Wはそれぞれ高さと幅を表します。 各サブピクセルは、256通りの離散変数として表されます。 ベンチマーキングデータセットとしてCIFAR10を使用しています。 参考となるTransformerの設計は、12層、512次元、4ヘッドの設計を踏襲した(Chen et al.)の設計をほぼ踏襲しています。 AdamW(Loshchilov & Hutter, 2019)を使用し、(Vaswani et al., 2017)のように標準的なウォームアップ学習率のスケジュールに従います。

我々は、1×10-3の初期学習率、すべての線形変換の重みに適用される0.1の重み減衰、および0.1のドロップアウト率を使用します。 シンプルなデータ増強を採用しています。 学習の際には,まず各画像をランダムに水平に反転させ,その後,すべてのサブピクセルに[-10, 10]の範囲の値を加算または減算し,結果のピクセル値を[0, 255]にクリップします. また,クロス・エントロピー損失を用い,デフォルトのバッチサイズを32とし,100回の学習エポックを行いました.

2種類のDAFTモデルを学習します。 DAFT-full(隠れ次元=256)と、DAFT-conv(1dカーネルサイズ=256)です。このケースでは、1d DAFT-convはピクセルの真の構造(2d)を利用していないことに注意してください。 この場合、1d DAFT-convは(2dである)ピクセルの真の構造を利用していません。

最新の技術との比較

CIFAR10は、画像の自己回帰モデリングのための混雑したベンチマークであり、表4に示すように、いくつかの競合するベースラインと比較します。 CIFAR10では、展開されたシーケンスの長さが3072であり、妥当なサイズの完全なTransformerを学習するにはすでに法外な長さであることに注意してください。 もう1つのベースラインはImage Transformer (Parmar et al., 2018)で、これは256のサイズのlocal2dウィンドウに注目を制限しています。 また、Synthesizer (Tay et al., 2020a) や Reformer (Kitaev et al., 2020) とも比較しています。 表4から、DAFT-fullはベースラインのTransformerと同様の性能を示し、Image Transformer、Synthesizer、Reformerを上回っていることがわかります。 興味深いことに,DAFT-convは,画素の2次元局所構造を利用していないにもかかわらず,DAFT-fullよりもさらに高い性能を示しています。 効率面では、DAFTの両バージョンとも、標準的なTransformerやその他のベースラインよりも高速で、メモリ消費量は半分程度です。

image

image

5.3 言語モデリング

自動回帰モデリングのベンチマークとして有名なEnwik8(Mahoney, 2011)を用いて、文字レベルの言語モデリングにDAFTを適用しました。 我々は、(Dai et al., 2019)にあるような標準的な前処理手順とトレーニング/バリデーション/テストの分割に従います。 ベースとなるTransformerのリファレンスは、2048のフィードフォワード次元を持つ12層512次元8ヘッドのアーキテクチャです。 最初の実験セットでは、シーケンス長を1024としています。 学習プロトコルは、重みの減衰を0.5に増やし、バッチサイズ128で100エポック分の学習を行うこと以外は、ほぼ前回の実験と同じです。 DAFT-fullとDAFT-convをカーネルサイズ64で評価しました。 また、いくつかの効率的なTransformerベースライン、すなわちReformer (Kitaev et al., 2020), Synthesizer (Tay et al., 2020a) , Linear Transformer (Katharopoulos et al., 2020), Performer (Choromanski et al., 2020)と比較しました。 表5によると,DAFTモデルはいずれも学習ビット数(bpc)が最も少なく,これはモデルの能力が高いことを示す指標となっています。 テスト性能は、基本的なTransformerよりも若干劣るものの、他のTransformerのバリエーションよりも優れています。

image