shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation #219

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2009.13818] A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation

著者・所属機関

Dinghan Shen, Mingzhi Zheng, Yelong Shen, Yanru Qu, Weizhu Chen

投稿日時(YYYY-MM-DD)

2020-09-29

1. どんなもの?

自然言語処理で使用されているBERTなどの巨大な言語モデルをFine-Tuningする際に,学習コストの高い敵対的学習や,余分なモデルやデータセットが要求される逆翻訳などの手法ではなく,単純に入力に使用されるEmbedding層の重みの一部を0に変換するCuroffを提案している.

外部モデルや余計な計算コストを必要としないが,これらの手法と同等程度の精度改善を達成している.

2. 先行研究と比べてどこがすごいの?

自然言語処理の分野では巨大のモデルを使用することで高い精度を発揮できるようになっている。先行研究では、事前学習モデルから得られる分散表現は、下流タスクのデータでFine-Tuningをした場合、汎化性能が減少することが経験的にわかっている。

この問題を解決するために、Fine-Tuningを行っている際に敵対的な損失関数を使用することで、分散表現に対して正則化を行うことが提案されている。具体的には、単語の分散表現に対してラベルを保持したままノイズを導入し、こうしたノイズに影響しないように予測できるようにすることである。

しかし、この手法では余分に逆伝搬が必要とされ、計算コストが増大することがわかっている。

本研究では、単純で効果的なData Augmentationの手法を提案している。具体的には、学習を行う際に単語を分散表現に変換するEmbedding層が最初に存在するが、このEmbedding層の一部を0に置き換える手法である。

増強したサンプルをより予測が難しいものにするために、入力される文から連続したいくつかの単語を削除する手法を用いる。これで文章から様々なSemantic Featureを抽出できるようになる。

あるデータを複数のパターンに増強したあと、これらのサンプル間の関係性を補足できるように、JS Divergenceを利用した正則化項を提案している。

3. 技術や手法の"キモ"はどこにある?

3.1 Motivation

本手法は、巨大な事前学習済みの言語モデルをFine-Tuningする場合、データ増強により有用なSemantic Featureを抽出できる能力を持たせることができるという仮説に基づいている

データ増強の関数をfとして場合、データxを増強したf(x)は、もとのデータxとラベルが変化しないようにする必要がある。またf(x)は十分多様性があるものにしなければならない。

今までの研究ではこの関数fの選択に、ガウスノイズや敵対的学習、逆翻訳などが提案されている。

既存の手法には上記のような弱点が存在しているため、本研究では計算コストがかからず、追加の言語モデルや外部データに依存しない形の手法を提案している。

そこでMulti-View Learningの視点を取り入れ、あるサンプルxから得られる増強済みのサンプルx1とx2から、それぞれのサンプルに対する予測をp1とp2とした場合、以下の条件を保持するように学習を行う。

image

3.2 Constructing Partial Views

Transformer-basedなモデルでは自己注意機構を使用しており、各出力のUnitはすべての入力Tokenにアクセスすることができるため、入力となるEmbedding層に対して多視点からの観点を導入することでモデルに依存しない手法にすることができる。

本研究では、入力されたある文章(L個のTokenで構成されており、各Tokenの次元数をdとしている)に対して、Embedding層の重みに関して単語や特徴量を一部0に変換するCutoffを提案している。

image

BERT系統のモデルでは入力値に使用されるEmbedding層には、Token-Embedding層やPositional-Embedding層、Segment-Embedding層などで構成されている。

3.3 Incorporating Augmented Samples

上記の手法により同じ文章xから、N個のCutoffしたサンプルを得ることができる。これらのサンプルは、すべて同じ意味情報を含んでいると仮定しているため、サンプル間の意味情報の一貫性が保たれるように以下のような目的関数を定義している。

image

ここで非対称なKL Divergenceを使用するとN個のサンプルに対して、2^{2N+1}個の計算が必要になってしまうため、以下のようにJS Divergenceを採用している。

image

4. どうやって有効だと検証した?

image

image

image

image

5. 議論はあるか?