yoheikikuta / paper-reading

Notes about papers I read (in Japanese)
156 stars 4 forks source link

[1911.02855] Dice Loss for Data-imbalanced NLP Tasks [paper-reading] #52

Open yoheikikuta opened 4 years ago

yoheikikuta commented 4 years ago

論文リンク

https://arxiv.org/abs/1911.02855

公開日(yyyy/mm/dd)

2019/11/07

概要

単純な cross entropy では、 accuracy のみを見て False Positive, False Negative 両方をケアできてない点、簡単なデータが大量にあったときにそれらの寄与が dominant になって見分けたい難しいケースを見分けづらい点、が問題となり得る。 前者は F1 score を拡張した Dice loss を用い、後者は Focal loss のアイデアを拝借して (1 - p) の weight を掛けることで解決を図り、実験して試したら良い結果でした、という論文。 imbalance と言っても実験で試していたのが 1:4 くらいの度合いで、現実はもっと激しいのでそういうケースに使えるアイデアがないかと思って読んだけどそういう話ではなかった。

yoheikikuta commented 4 years ago

なんか面白い論文ないかな〜と思って ACL 2020 の accepted paper 眺めてたら見つけたものの一つ。 自分が今後取り扱っていく問題もかなり data-imbalanced なものになりそうなので、使えそうならいいなと思って読んでみることにする。

Dice loss っていうと semantic segmentation とかで使う Dice loss だよね?という感じがするが、画像の場合は予測領域と正解領域の intersection / union とするところを NLP タスクに応用するという話なのだろうか。軽く眺めるとそうっぽいが imbalance を解決するために何かしらの工夫はしていそう。

とりあえず読んでいくことにする。

yoheikikuta commented 4 years ago

data imbalance な状況で生じるのは以下の二つの問題で、そのどちらも現在は適切には解決されていないという認識。

yoheikikuta commented 4 years ago

最初の問題を解決するためには False Positive, False Negative 両方に等しく寄与する loss function が必要で、そのために Dice loss や Tversky index を使うというのが、この論文の一つ目のアイデア。

二つ目の問題はこれだけでは解決しないので、Focal loss に inspire されて学習データに対して (1-p) の重みをつけて学習するというのが、この論文の二つ目のアイデア。

なんかこれだけ読むと CV のアイデアを NLP に持ってきてみました、という感じに見えなくもないな...

ちなみに Focal loss に関しては以前この paper-reading で読んでいた: https://github.com/yoheikikuta/paper-reading/issues/19 昔読んだものがたまたま出てくる程度にはこの paper-reading も情報が増えてきたんだなぁと感慨も一入である(本当か?)。

yoheikikuta commented 4 years ago

提案手法な具体的な内容に入る前に notation などの確認。 まず、この論文では binary classification の場合を例にとって説明しているが、提案手法は multi-class classification に拡張可能である。表式は multi-class への拡張を意識している形になっている。

みんな大好き cross entropy は以下。 $ i $ の添字はデータの添字で、最大で $ N $ まで走る。$ j $ の添字がクラスの添字で、今回は binary classification を念頭に置いてるので 0,1 のどちらかになっている。multi-class の場合はここを multi-class に対応させればよい。$ y{ij} $ が正解ラベルで $ p{ij} $ が予測確率である。

cross entropy

式から明らかなのは、それぞれのデータは等しく loss に寄与するということだ。 なので imbalance なデータだと大量にデータがあるクラスの影響が大きくなりすぎたりする。それを防ぐには resampling したり、もしくは以下のようにデータに重み α ∈ [0, 1] を付与する(どちらもデータ分布を変えるという意味では同様のアプローチである)。

weighted cross entropy

ちなみに multi-class の場合にこの α をどう決めるのかは結構トリッキーで、selection bias がかかるのであまり使われてないとのこと。

yoheikikuta commented 4 years ago

次に提案手法に組み込むことになる Dice coefficient と Tversky index の説明。 どちらも初出はめちゃ昔で、前者は 1945 で後者は 1977 に提案っぽい。

Dice coefficient は、二つの集合 A,B が与えられときに以下で定義される。 binary classification の枠組みで考えると、A があるモデルで予測された全ての positive instance から成る集合で、B はデータの中の全ての golden positive instance (golden とか言ってるけどつまり正解データ)である。

DSC

モデルが positive と予測したデータが全て正しく正解ラベルになるなら 1 になるし、逆に本当は全て negative ならば 0 になる。 binary classification の TP, FN, FP の言葉で書くと 2TP / (2TP + FN + FP) = F1 となることがわかり、これは F1 スコアと同等のものになる。

これを個別の instance $ x_i $ に対して書き換えると以下のように書ける。

DSC(x)

これも 1 を添字にしてその添字の sum を取ることで multi-class に書き換えることは可能だが、(binary classification の場合と同じく) 寄与するのは y = 1 となる instance のみで、それ以外の instance の寄与はモデルの予測に関わらずなくなってしまう。 これを回避する一つの簡単な方法は、分母と分子の両方に以下のような factor を導入することである。

DSC(x) with γ

単純ではあるがこれで $ y_{i1} = 0 $ の場合でも非ゼロの寄与が生まれる。 Milletari et al. (2016) ではこれを以下のように分母を二乗にした形で dice loss を定義している(その方が学習がしやすい。ちなみにラベルは 0 or 1 なので二乗しても変わらないけど式の見た目の対称性のために二乗)。

DL

1 parameter γ で negative instance の寄与も取り込めるので、シンプルな割にはなかなか悪くない感じはする。

yoheikikuta commented 4 years ago

Tversky index (TI) は Dice coefficient を拡張して A, B を非対称に扱うもので、以下で定義される。

TI

これは $ α = β = 1/2 $ のときに Dice coefficient を再現することは簡単に分かる。 これも loss の形で書いておくと以下のようになる。

TL

パラメタが増えた分、A, B の取り扱いをよりコントールできるようになった。 でもこれくらい増えるとチューニングの余地がありすぎてイマイチではある。

yoheikikuta commented 4 years ago

上で出てきたような DL, TL を使えば、これらが F1 score の拡張になっていることからも分かるように、False Positive, False Negative 両方の影響を受けるものであるから冒頭の問題点の一つ目はクリアできる。

一方で easy example が多い場合にそれの影響が dominant になってしまうという問題にはアプローチできていない(冒頭では物体検出でおなじみの easy negative の問題だけに言及していたが、ここでは pos/neg 両方の easy example について考える)。 具体的にはめちゃくちゃ簡単な postive data で p = 1 となるデータが大量にあった場合、これらの loss への影響が dominant になってしまう。p = 0 となる簡単な negative data に関しては寄与は小さい。

これの解決に Focal loss のアイデアを使う。

p が 1 に近い場合に loss への影響が小さくするようにすればいいのだから (1 - p) のような weight をかけてやるのが単純なソリューションであろう。ということで以下の loss を採用する。

こうなってくるとそもそもの F1 score の拡張だった話とかはどこ吹く風という感じだが、まあこういう modification でも結果がよければそれでヨシというのが昨今の機械学習ということなんでしょう。Focal loss とかもそんな感じだったしね。

yoheikikuta commented 4 years ago

この論文で登場した loss をまとめたものが以下の表。(1-p) factor は Tversky loss でも同様に考えられるはずだけど、dice loss のみ試しているみたい。DSC の分子の factor 2 が落ちてるけど typo ですね。

loss の微分をグラフにしたものが以下。 ここはちょっと理解し切れてないが、まず DL や DSC の微分の計算としては y = 1, γ = 0 の場合で考えて振る舞いをみている。ただ overall factor がおかしく、DL の場合は 1/2 したものになっていそう。DSC も同様かと思ってたけど、上で述べたように分子の factor 2 がないのが typo じゃなかったとしたらこれで合ってる。まあ overall factor として処理できるところなのでそんなに神経質にならんでもいいかもしれない。 グラフの振る舞いは DCS だけが定性的に異なっていて、p=1/2 で 0 になりそこで符号が変わるようになっている。論文の説明だとこれが p < 1/2 の instance である場合は p = 1/2 になるように学習を進め、そこを超えた(正しく認識できるようになった)らそのような instance の寄与は小さくなるから無視されるようなもんだと書いてある。 なんかそれっぽく書いてあるが、これは p > 1/2 の instance も p = 1/2 に向かうように学習されるので、学習が安定しないんじゃないかと思うのだが。loss も p = 1/2 で小さい値をとるようになっているのでそれはいいのだが、適切に見分けられるように学習されるのかはちと疑問。

yoheikikuta commented 4 years ago

実験は POS tagging, NER, machine reading comprehension, paraphrase identification の 4 つで、BERT や XLNet に提案手法を含む loss を組み合わせて試して結果が改善するかを見ている。

よくなっていますということなのだが、そんなに顕著でもなくめちゃくちゃ面白い結果というわけでもないので、machine reading comprehension の結果だけを貼っておく。

yoheikikuta commented 4 years ago

Ablation study では Spacy を使って DBpedia を介して entity を置き換えて pos/neg を augment してデータバランスを変えたりして結果がどうなるか調べている。the Quora Question Pairs (QQP) はもともと pos:neg = 0.37:0.63 の imbalance なデータで、これを 0.5:0.5 にしたり 0.21:0.79 にしたりしている。そのほかにも negative sampling なども試しているが、結果は以下。

一番上が cross entropy での結果で、どのパターンでも単純な cross entropy よりも効果がある。 一方で問題が簡単で差が小さいので DSC が Focal loss や DL よりも優位に良い結果かと言われるとそうでもない。

あと気になったのが imbalance 度合いが大したことないとい点。 1:4 くらいにして試しているが、現実の imbalance 度合いはもっと極端なので、そういう状況設定で考えて欲しいところではある。object detection では極端な imbalance 度合いになるので Focal loss がめちゃ効いたわけだけど、そういう状況設定になってないのでちょっとアイデア拝借しましたという感じが否めない。

yoheikikuta commented 4 years ago

ということで一通り読んだ。 Dice loss に Focal loss のアイデアを組み合わせて、loss を改善して実験したら良い結果だったという論文。

imbalance な状況設定に興味があって読んだが、論文ではそんなに激しい imbalanced dataset の状況を考えていなかった。