Open e4exp opened 3 years ago
1 INTRODUCTION
Transformerアーキテクチャ(Vaswaniら、2017)は、自然言語処理で広く使用されており、多くのタスクで最先端の結果をもたらします。 このような結果を得るために、研究者はこれまで以上に大きなTransformerモデルを学習することに頼ってきました。 Shazeer et al., 2018)で報告された最大の構成では、パラメータ数が1層あたり0.5Bを超え、(Al-Rfou et al., 2018)では層数が64まで増えています。 変圧器モデルは、ますます長いシーケンスにも使用されます。 1つの例で最大11,000トークンのテキストが(Liu et al., 2018)で処理され、音楽(Huang et al., 2018)や画像(Parmar et al., 2018)のような他のモダリティを処理する際には、さらに長いシーケンスが当たり前になります。 これらの大規模な長文シーケンスモデルは素晴らしい結果をもたらしますが、この傾向がNLP研究を壊しているという意見もあるほど、リソースを圧迫しています1。 多くの大規模なTransformerモデルは、現実的には大規模な産業研究所でしかトレーニングできず、モデル並列処理でトレーニングされたこのようなモデルは、1つのトレーニングステップでもメモリ要件がマルチアクセラレータのハードウェアセットアップを必要とするため、1つのGPUで微調整することさえできません。 大規模なTransformerモデルは、基本的にこのような膨大なリソースを必要とするのでしょうか、それとも単に非効率なのでしょうか。 報告されている最大のTransformerレイヤーで使用されている0.5B個のパラメータは、2GBのメモリを占めています。 埋め込みサイズ1024、バッチサイズ8の64Kトークンの起動には、64K×1K×8=0.5B個の浮動小数点数が必要で、さらに2GBのメモリが必要になります。 もしもメモリ使用量がレイヤーごとにしかないのであれば、64Kの長さのシーケンスであっても、1台のアクセラレータで大規模なTransformerを簡単に動かすことができるはずです。 さらに、BERTの学習に使用したコーパス全体の保存には17GBしか必要ありません。ではなぜ、これらのモデルを単一のマシンで微調整することさえできないのでしょうか。
上記の見積もりには、レイヤーごとのメモリと入力アクティベーションのコストのみが含まれており、Transformerのメモリ使用の主な原因である以下の点は考慮されていません。
これらの問題を以下のような手法で解決するReformerモデルを紹介します。
これらの技術を研究し、標準的なTransformerと比較して、学習プロセスへの影響はごくわずかであることを示しました。 活性化の分割は、実際には実装に影響するだけで、Transformerで使用されている層と数値的には同じです。 標準的な残差の代わりに可逆的な残差を適用すると、モデルは変化しますが、実験したすべての構成において、学習に対する影響はごくわずかです。 最後に、注目のロカリティセンシティブハッシュは、より大きな変化であり、使用する同時ハッシュの数に応じて、学習ダイナミクスに影響を与えます。 このパラメータを検討し、効率的に使用でき、かつフルアテンションに非常に近い結果が得られる値を見つけました。 合成タスク、長さ64Kの配列を持つテキストタスク(enwik8)、長さ12Kの配列を持つ画像生成タスク(imagenet-64 generation)で実験を行った。 どちらの場合も、ReformerはTransformerのフル機能で得られた結果と一致しますが、特にテキストタスクでははるかに高速に動作し、メモリ効率も桁違いに良いことを示しています。
2 LOCALITY-SENSITIVE HASHING ATTENTION
ドットプロダクトアテンション Transformerで使用される標準的なアテンションは、スケーリングされたドットプロダクトアテンションです(Vaswani et al., 2017)。 入力は、次元dkのクエリとキー、および次元dvの値で構成されています。 クエリとすべてのキーのドットプロダクトが計算され、√ dkでスケーリングされ、ソフトマックス関数が適用されて値の重みが得られます。 実際には、一連のクエリに対する注目関数が同時に計算され、行列Qにまとめられます。 キーと値も行列KとVにまとめられていると仮定すると、出力の行列は次のように定義されます。
マルチヘッドアテンション。 Transformerでは、dモデル次元のキー、値、クエリに対して1つのアテンション機能を実行するのではなく、クエリ、キー、値をそれぞれdk、dk、dv次元に学習した異なる線形投影でh回線形投影します。 これらの投影されたバージョンのクエリ、キー、値のそれぞれにアテンションが並行して適用され、dv次元の出力値が得られます。 これらは連結され、再び投影され、最終的な値が得られます。 このメカニズムはマルチヘッドアテンションと呼ばれています。
メモリ効率の良い注意 注意メカニズムのメモリ使用量を計算するために、式1の注意計算に注目してみよう。 Q、K、Vはすべて[batch size, length, dmodel]の形をしていると仮定しよう。 ここで問題となるのは、QK^Tという項で、これは[batch size, length, length]という形をしている。 この場合、バッチサイズが1であっても、64K×64Kの行列となり、32ビット浮動小数点数では16GBのメモリを必要とします。 これは現実的ではなく、長い配列に対するTransformerの使用を妨げていました。 しかし、QK^T行列はメモリ上で完全に実体化されている必要はないことに注意することが重要です。 アテンションは、実際に各クエリq_iに対して個別に計算することができ、メモリ内で一度だけsoftmax( q_i K^T / √d_k )Vを計算し、グラデーションが必要な場合はバックワードパスで再計算します。 このアテンションの計算方法は効率が悪いかもしれませんが、長さに比例したメモリしか使用しません。 このメモリ効率の良いアテンションの実装を用いて、実験セクションで紹介するフルアテンションのベースラインを実行します。
Q、K、Vはどこから来るのか? 上述のマルチヘッドアテンションは、キー、クエリ、値を操作しますが、通常は[バッチサイズ、長さ、dmodel]という形状の活性化Aの単一テンソルしか与えられません。 AからQ、K、Vを構築するために、Transformerは3つの異なる線形層を使用し、Aを異なるパラメータでQ、K、Vに投影します。 LSHに注目したモデルでは、クエリとキー(QとK)を同一にしたいと考えています。 これは、AからQとKへの投影に同じ線形層を使用し、Vには別の線形層を使用することで簡単に実現できます。 このような動作をするモデルをshared-QK Transformerと呼びます。 実験的なセクション5で示すように、鍵Kの長さを追加で正規化しても、QKの共有はTransformerの性能に影響を与えないことが判明しました。
ハッシングの注意 LSHのアテンションでは、[batch size, length, dmodel]という形の2つのテンソル、Q=KとVから始めます。 マルチヘッドのメカニズムはそのままにして、式1からのアテンションの計算に注目する。 すでに述べたように、主な問題は、[バッチサイズ、長さ、長さ]の形を持つQK^Tという項です。 しかし、実際にはsoftmax(QK^T )にしか興味がないことに注意してください。 softmaxは最大の要素に支配されるので、各クエリq_iに対して、Kの中でq_iに最も近いキーに注目すればよいのです。 例えば,Kが64Kの長さであれば,各q_iに対して,例えば32個または64個の最も近い鍵の小さなサブセットだけを考慮すればよいのです. これは非常に効率的な方法ですが、ではどうやって鍵の中から最も近いものを見つけるのでしょうか?
ロカリティセンシティブハッシング 高次元空間において最近傍を素早く見つけるという問題は、Locality-sensitive Hashing (LSH)によって解決できる。 各ベクトルxにハッシュh(x)を割り当てるハッシュ方式は、近くのベクトルが高確率で同じハッシュを取得し、遠くのベクトルは取得しない場合、ロカリティセンシティブ(locality-sensitive)と呼ばれます。 今回のケースでは、実際には、近くのベクトルが高確率で同じハッシュを得ることと、ハッシュバケットが高確率で同程度のサイズであることを要求しているだけである。 これを実現するには、次のようにランダムな投影を行います(図1参照)。 b個のハッシュを得るために、まず、サイズ[d_k, b/2]のランダムな行列Rを固定する。 そして、h(x) = arg max([xR; -xR]) (ここで、[u; v]は2つのベクトルの連結を表す)と定義します。 この方法は、既知のLSHスキーム(Andoni et al., 2015)であり、実装が容易で、ベクトルのバッチにも適用できます。
LSHアテンション 我々のLSHスキームとハッシュ化されたアテンションの一般的な考え方を知った上で、本稿で使用するLSHアテンションを正式に定義する。 まず、通常のアテンションの方程式である(1)を、一度に1つの問い合わせ位置iに対して書き換えます。
位置iのクエリが属するセットを表すためにP_iという表記を導入し、パーティション関数(つまりソフトマックスの正規化項)を表すためにzという表記を導入します。 また、わかりやすくするために、√d_kによるスケーリングを省略しています。 バッチングのために、我々は通常、より大きなセットP_i~ = {0, 1, ... ... , l}⊇P_iに対してアテンションを実行し、P_i に含まれない要素をマスクする。
ここで、LSHアテンションについて考えてみましょう。 これは、単一のハッシュバケット内でのアテンションのみを許可することで、問い合わせ位置iがアテンションできるターゲットアイテムのセットP_iを制限するという観点から考えることができます。
図2(a-b)は、フルアテンションとハッシュ化されたバリアントの比較を模式的に示しています。 (a)では、フルアテンションのアテンションマトリクスは一般的にスパースであるが、計算ではこのスパース性を利用していないことが描かれている。 (b)では、クエリとキーがハッシュのバケットに応じてソートされています。 似たようなアイテムは高い確率で同じバケットに入るので、各バケット内での注意のみを許可することで、完全な注意パターンを近似することができます。
この方式のハッシュバケットはサイズが不揃いになりがちで、バケット間のバッチ処理が困難になります。 さらに、バケット内のクエリの数とキーの数が不揃いになる可能性があり、実際、バケットには多くのクエリが含まれているが、キーは含まれていないということもあり得ます。 これらの問題を軽減するために、まず k_j = q_j / || q_j || とすることで、h(k_j ) = h(q_j ) を確保します。 次に、クエリをバケット番号でソートし、各バケット内では配列位置でソートします。 これにより、ソート後にi →s_iとなる順列が定義されます。 ソートされた注目マトリクスでは,同じバケットからのペアは対角線付近に集まります(図2cに描かれています). m個の連続したクエリのチャンク(ソート後)がお互いにアテンションし、1つ後ろのチャンクにアテンションするというバッチングアプローチをとることができます(図2d)。 先ほどの表記法に従うと、これは設定に相当します。
max_i|P_i|<mであれば、P_i⊆Pi~となります。 実際には,m = 2l / n{buckets} とする(l は配列の長さである). 平均的なバケツの大きさは l / n_{buckets} であり、バケツがその2倍の大きさになる確率は十分に低いと仮定する。 LSHの注目度の全体的なプロセスは、図2にまとめられている。
マルチラウンドLSHの注目点 ハッシュ処理では、似たようなアイテムが異なるバケットに入る確率が常にわずかに存在する。 この確率は、n_{rounds} 個の異なるハッシュ関数 {h^(1), h^(2) , ...} を用いて複数ラウンドのハッシュ処理を行うことで低減できる。
多ラウンドの場合、基本的にはLSHの注目度をnラウンドずつ並行して行うことになります。 この手順の詳細は付録Aに記載されています。
共有QK注意のための因果的マスキング。 Transformerデコーダでは、マスキング(式3のm(j,P_i)で示される)を使用して、位置が未来に向かってアテンドするのを防ぎます。 LSHアテンションでマスキングを実装するには、すべてのクエリ/キー・ベクトルに位置インデックスを関連付け、クエリ/キー・ベクトルのソートに使用されたのと同じ並べ替えを使用して位置インデックスを再配置し、比較演算を使用してマスクを計算します。
未来への注意は許可されていませんが、Transformerの典型的な実装では、ある位置が自分自身に注意を向けることができます。 このような動作は、共有QK方式では望ましくありません。 なぜなら、クエリ・ベクトルとそれ自身とのドット・プロダクトは、ほとんどの場合、クエリ・ベクトルと他の位置のベクトルとのドット・プロダクトよりも大きくなるからです。 そのため,トークンが他に有効な注目対象を持たない場合(例えば,シーケンスの最初のトークン)を除き,トークンが自分自身に注目することを禁止するようにマスキングを修正しました.
2.1 合成タスクでの分析
LSH注目の性能を検証し、その動作を研究するために、次の合成タスクから始めます: シンボルのシーケンスを複製する。 このタスクでは、トレーニングとテストの各例は、0w0w(w∈{1, ... ... , N})の形をしています。, N} ∗ ここで,w∈{1, ... , N}は,1からNまでの記号列である(実験ではN=127とした). 以下に,長さ3の単語wの例を示す.
LSHの注意を調べるために、各wが長さ511である(つまり、入力0w0w全体が長さ1024である)上記の形式の例で言語モデルを訓練します。 これは言語モデルの課題なので、前の記号がすべて与えられれば、常に次の記号を予測しますが、入力の後半の位置、つまり実際に予測できる位置のみを考慮して、損失と精度をマスクします。
上記の課題は、1層のTransformerモデルで完璧に(精度100%、損失0)解決できます。 しかし、この課題は非局所的な注意の検索を必要とするため、限られたスパンでの疎な注意に依存するモデルでは解決できないことに注意してください。 NLPで使用されるモデルと同様に、簡単かつ高速に学習できるように、d_model = d_ff = 256、4ヘッドの1層Transformerを使用しています。 フルアテンション、LSHアテンション、n_rounds = 1、n_rounds = 2、n_rounds = 4の4種類の設定で、150Kステップの学習を行いました。
表2にまとめた結果から、フルアテンションで学習したモデルは、LSHアテンションでもすぐに使用できるが、多少精度が落ちることがわかった。 また、4つのハッシュで学習したモデルは、LSHアテンションでゼロから学習した場合、ほぼ完璧な精度を達成しています。 興味深いのは、8個のハッシュで評価したときに精度が完璧になることです。 ハッシュ数が2または1の場合は、精度が低下します。 少ないハッシュで学習したモデルは結果が悪くなりますが、1つのハッシュで学習したモデルでも、8つのハッシュで評価するとほぼ完璧になります。
3 可逆変圧器
上のセクションで示したように、近似が受け入れられるならば、注目の複雑さは長さが2乗のものから線形のものに減らすことができます。 しかし、表1から明らかなように、各フィールドはb ・ n_h ・ lの項で始まります。 b ・ n_h ・ l ・ d_k、あるいは代わりにb ・ l ・ d_modelのコストを避けることはできません。 実際、各層の前の活性化は既に b ・ l ・ d_model というサイズであるため、n_l 層のモデル全体のメモリ使用量は少なくとも b ・ l ・ d_model ・ n_l となる。 さらに悪いことに、Transformerのフィードフォワード層の内部では、これはb ・ l ・ d_ff ・ n_lにまで上がります。 大型のTransformerでは、d_ff = 4K、n_l = 16とするのが一般的で、l = 64Kとすると、これもまた16GBものメモリを使用することになります。
このセクションでは、まずリバーシブル・レイヤーを使って項のn_l部分を処理し、次にチャンキングによってd_ff問題を処理できることを示すことで、このコストを削減する方法を示す。 これらの各アプローチのメモリと時間の複雑さに対する効果を表3にまとめた。
RevNets(レブネット)。 可逆的残差ネットワークは、Gomezら(2017)によって導入され、画像分類のためにResNetsを置き換えることができることが示された。 主なアイデアは、モデルのパラメータのみを使用して、任意の層の活性化を次の層の活性化から復元できるようにすることです。 バックワードパスで使用するために中間値をチェックポイントする必要はなく、ネットワークの出力からその入力へとバックプロパゲーションが進むにつれて、層を1つずつ反転させることができます。 通常の残差層は、単一の入力で動作し、単一の出力を生成する関数x →yを実行し、y = x + F(x)という形式をとるのに対し、可逆層は入力/出力のペアで動作します。(x1, x2) → (y1, y2)という入力と出力のペアで動作し、次のような式に従います。
レイヤーは、残差を(足すのではなく)引くことで反転させることができます。
リバーシブル・トランスフォーマー Revnetブロック内の注目層とフィードフォワード層を組み合わせることで、RevNetのアイデアをTransformerに適用します。 上の表記では、Fがアテンション層になり、Gがフィードフォワード層になります。 なお、Layer Normalization(Ba et al., 2016)は残差ブロックの内部に移動しています。
可逆Transformerは、各層に活性化を保存する必要がないため、n_l項が不要になります。 第5節では、同じ数のパラメータを使用した場合に、通常のTransformerと同じ性能を発揮することを示します。 これを実現するには、x1とx2の両方のサイズをd_modelにする必要があります。
チャンキング。 可逆性はn_l項をカバーしていますが、厚みのある層はまだ多くのメモリを使用することができます。 特にフィードフォワード層では、次元数d_ff = 4K以上の中間ベクトルを使用することができます。 しかし、フィードフォワード層での計算は、シーケンス内の位置によらず完全に独立しているため、計算をc個のチャンクに分割することができます。
このレイヤーは通常、すべてのポジションに対する演算を並行して行うことでバッチ処理を行いますが、1つのチャンクに対して一度に演算することでメモリを削減することができます。 (8)の逆計算とバックワードパスもチャンク化されます。 フィードフォワード層に加えて、語彙数の多いモデル(d_model以上の単語タイプ)では、出力の対数確率もチャンク化し、シーケンスの各セクションの損失を一度に計算します。
チャンキング、ラージバッチ、パラメータの再利用。 チャンキングとリバーシブルレイヤーにより、ネットワーク全体の活性化に使用するメモリはレイヤーの数に依存しません。 しかし、パラメータはレイヤー数に応じて増加するため、同じことは言えません。 しかし、この問題は、層のパラメータを、その層が計算していないときに、CPUメモリと入れ替えることで解決します。 標準的なトランスフォーマーでは、CPUへのメモリ転送が遅いため、この方法は非効率的です。 しかし、Reformerでは、バッチサイズに長さを掛けたものがはるかに大きいため、パラメータで行われる計算量で、その転送コストを償却することができます。
5 実験
本節では,上述の技術を実証する実験結果を示す。 どのような組み合わせが性能に影響を与えるかを明らかにするために、1つ1つの技術を分析します。 まず、可逆的なレイヤーと共有クエリキー空間が性能に影響を与えないことを示し、次にハッシュの注目度を分析し、最後にReformerのフルモデルを分析します。 実験はimagenet64とenwik8-64Kタスクで行いました。 enwik8はenwik8の変形版で、2^16 = 64Kトークンのサブシーケンスにチャンクされています。 メモリ使用量が多く、完全なO(l^2)注意を行う通常のTransformerと比較して扱いやすいように、3層モデルをアブレーショ ンに使用しています。 すべての実験では、d_model = 1024, d_ff = 4096, n_heads = 8, 合計バッチサイズは8シーケンスとしました。 これらのモデルのトレーニングには、Adafactorオプティマイザー(Shazeer & Stern, 2018)を使用しました。 また、Vaswaniら(2017)のハイパーパラメータに従って、WMT 2014英独翻訳タスクで評価しました。 すべての実験の学習は、8つのデバイス(8つのGPUまたは8つのTPU v3コア)で並列化されました。 モデルを学習するためのコードは公開されています2。
QK共有の効果
まず、通常のTransformerモデルにおけるShared-QK attentionの効果を検討します。 Shared-QK attentionは、k_j = q_j / ||q_j||と設定し、トークンが自分自身に注目することを防ぎます(他のコンテキストが利用できない場合を除く)。 図3の左部分では、通常の注意とShared-QK注意の両方について、perplexity曲線をプロットしています。 共有QK空間は、通常のアテンションよりもパフォーマンスが悪くなることはありません。 実際、enwik8では、わずかに速く学習できるように見えます。 つまり、shared-QKアテンションに切り替えても、精度が犠牲になることはないのです。
リバーシブルレイヤーの効果。
図3の右の2つのプロットでは、Vaswaniら(2017)による通常のTransformerと、セクション3で説明した可逆的なTransformerを比較しています。 この2つのモデルはパラメータ数が同一であり、学習曲線も同様にほぼ同じに見える。 これらの結果は、可逆的なTransformerにおけるメモリの節約が、精度を犠牲にしないことを示しています。
機械翻訳におけるリバーシブル・レイヤー
また、英語からドイツ語への機械翻訳のためのエンコーダ・デコーダのTransformerモデルのコンテキストで可逆層を評価します。 まず、Transformerベースのアーキテクチャで、エンコーダとデコーダの両方を完全に可逆的にすることから始め、得られたモデルが100Kステップで学習した場合に、Vaswani et al. また、より多くのステップ数で、より大きなモデルを使ったトレーニングを評価します。 リフォーマーモデルは非常にメモリ効率が良いので、後者の2つの実験では、モデル全体で埋め込みと出力投影の重み行列を共有することでメモリを節約する必要はありません。 結果を表4に示す。 この実験では、例文が単文であり、単文が比較的短い傾向にあるため、LSHアテンションを適用していない。 我々の典型的なLSHアテンションの構成では、ハッシュとソートの後に128トークンのチャンクを使用するが、WMT14テストセットの例はすべて128トークンよりも短い。
大規模なリフォーマーモデル
リフォーマーが実際に大規模なモデルをシングルコアに収めることができ、長いシーケンスで高速にトレーニングできることを検証するために、enwik8とimagenet64で最大20レイヤーの大規模なリフォーマーをトレーニングした。 図5に見られるように、これらのモデルはメモリに収まり、学習することができました。 この場合、Transformerのベースラインは遅すぎてメモリを消費するため、訓練することができませんでしたが、層の数を増やすことで明らかな改善が見られます。 enwik8の12層モデルを2万ステップ、ドロップアウト率0.1で学習させたところ、テストセットで1.19ビット/dimを達成しました。 また、12層のReformerモデルをさらにチューニングして長時間学習させたところ、enwiki8のテストセットで1.05 bits/dimを達成しました。
6 おわりに
Reformerは、Transformerのモデリング能力と、レイヤー数が多いモデルでも長いシーケンスを効率的に実行でき、メモリ使用量が少ないアーキテクチャを兼ね備えています。 これにより、大規模で豊富なパラメータを持つTransformerモデルがより広く普及し、利用しやすくなると考えています。 また、長いシーケンスを扱えるようになったことで、多くの生成タスクにReformerを使用する道が開かれました。 Reformerは、非常に長いコヒーレントなテキストを生成するだけでなく、時系列予測、音楽、画像、ビデオの生成など、他のドメインにもTransformerモデルの力をもたらすことができます。
大規模なトランスモデルは日常的に多くのタスクで最先端の結果を得ていますが、これらのモデルのトレーニングは、特に長いシーケンスでは法外なコストがかかることがあります。 我々はトランスフォーマーの効率を向上させるために2つの手法を紹介する。 1つは、ドット積注意をロカリティ依存のハッシュを使用するものに置き換え、その複雑さをO(L2)からO(LlogL)(Lはシーケンスの長さ)に変化させます。 さらに、標準的な残差の代わりに可逆的な残差層を使用することで、N回の学習プロセスではなく、1回だけの学習プロセスで活性化を保存することができます。 結果として得られたモデルであるリフォーマーは、トランスフォーマーモデルと同等の性能を発揮する一方で、よりメモリ効率が高く、長いシーケンスでははるかに高速である。