yoheikikuta / paper-reading

Notes about papers I read (in Japanese)
156 stars 4 forks source link

[2019] When Does Label Smoothing Help? #39

Open yoheikikuta opened 4 years ago

yoheikikuta commented 4 years ago

論文リンク

https://arxiv.org/abs/1906.02629

公開日(yyyy/mm/dd)

2019/06/06

概要

label smoothing が何を達成していてどういう場合に有効かを調べた論文。 penultimate layer の出力を調べると、label smoothing によって各クラスがよりはっきりと固まる(凝集する)ようになり、これは calibration (model confidence と accuracy が相関している) が良くなる方向に寄与する。 これは汎化性能を上げたり、beam-search が向上することで機械翻訳などで有用になったり、という利点がある。この辺は label smoothing だけでなく temperature scaling も合わせて分析している。 しかし、蒸留において teacher network に label smoothing を適用すると student network の性能が劣化することを発見。 これは上述のように学習した表現がより凝集することで、異なるクラス間の相互情報量が減少し、それが蒸留において必要なものだったと述べている。

yoheikikuta commented 4 years ago

label smoothing のようなちょっとしたテクニックっぽいのでモデルの性能を上げる、というのはあまり関心を惹かなかったが、解釈ができるものならば面白いかもと思って読んでみることにした。

yoheikikuta commented 4 years ago

まず label smoothing とはなんだったのかを思い出しておく。 K classes classification で、$ p_k $ がモデルの出力であるクラス k の確率で、$ y_k $ は正しいクラスなら 1 でそうでなければ 0 とする。label smoothing は以下。

ほとんどのモデルでそうであるように、$ p_k $ は penultimate layer の出力 x に対して $ w^T x $ を計算してそれを softmax 関数に入れた結果になっている($ p_k = softmax(w^T x) $)。ここで x は bias を含めるために 1 を concat している。

こいつの意味を考えるのは論文の主題でもあるので少し先送りにする。 label smoothing は簡単な割に汎化性能が向上するということで色々なモデルで使われていた。例えば以下。

yoheikikuta commented 4 years ago

論文の内容には直接関係ないが、logit という単語について思いを馳せてみる。

この論文で言うところの logit はずばり $ w^T x $ で softmax 関数への入力値のことである。 このような意味での logit は割と広く使われるもので、例えば TensorFlow とかはこのように使っている。

しかし、logit といえば $ logit(p) = \log(p / (1-p)) $ だろうと思っていた人にとってはちょっと confusing になる。もともとある値を sigmoid 関数に入れることで [0,1] に map して確率と見なすことが多かったので、このある値のことを logit と定義して、すなわち表式は sigmoid の逆関数になっていたのだろう。

昨今のディープラーニング界隈では sigmoid より softmax を使うことの方が多いので、むしろ softmax への入力を logit と呼ぶことの方が自然なのだろう。

yoheikikuta commented 4 years ago

label smoothing の表式が意味するところを考えてみる。

正しいラベル($ y_k = 1 $)に対しては確率を 1 に近づけるように、正しいクラスの logit の値は大きくして、それ以外のクラスの logit の値は小さくする。このとき、重要なのは正しいクラスの logit の値を大きくすることで、それ以外のクラスの logit をどう小さくするかは(比較的)小さな問題で、特に、それ以外のクラスの logit はバラついていてもよい(それを抑制する働きをするものはない)。そしてこれは label smoothing ありなしに関わらず同様である。

一方で、正しくないラベル($ y_k = 0 $)に対しては、$ α / K $ term が存在することで logit を等しくするように学習が進む。めちゃくちゃ簡単な例で言えば $ -(log(x) + log(1 - x)) $ は x = 0.5 が最小という話。

この正しくないラベルに対して logit を等しくするように、という部分をもう少し掘ってみる。 logit で計算しているのは $ x^T w_k $ であるが、これは $ |x - w_k|^2 $ の cross term である。$ x^T x $ は weight に寄らないので softmax の計算では factored out される。さらに $ w_k^T w_k $ は(規格化をすれば)クラス間で変化はなく一定で、これも factored out される。したがって $ x^T w_k $ の計算は実質的には $ x $ と $ w_k $ のユークリッド距離を測っていることになる。 この観点を持ち込むと、あるデータ x は正しいクラスの weight $ wk $ には近くなって、異なるクラスの weight $ w{k'} $ とは遠くなる。しかも異なるクラスの logit は等しくなるようにするのが label smoothing だったので、等しく離れるようになる。

このように、label smoothing には「同一クラスの特徴量は近い点に集まるようにして、それぞれのクラスは等しく離れるようにする」という働きを有することが分かる。

yoheikikuta commented 4 years ago

絵を見せよう。論文で提案された作り方は以下。

1, 2 行は CIFAR10, 100 の結果で確かに Label Smoothing (LS) ありでそれぞれのクラスが等しく離れるように map されていることが分かる。 3, 4 行は ImageNet だが上は 3 つとも似てないクラスで下が 2 つは似てて 1 つは似てないクラスを選んだ結果。LS の loss だけを見ればラベルが異なれば等しく離れるようにするものだが、似てるものはモデルは混同しやすいので特徴量も近いところに map されていることが分かる。

このプロットの仕方は論文の一つの貢献とのこと。これがどれくらい価値あるものかはちと分からない。 1 つの weight は高次元空間中の点なので、3 点選べば平面が描けるということで、そういうのを切り出して綺麗に見せられるようになったのは新しいのかな?特徴量をガッと集めて PCA とか t-SNE とかを使う、みたいな話ではないので。

これらのモデルの汎化性能は以下。 LS は汎化性能に寄与することがあるが、ここで試した結果としては優位な差はなさそう。定義から明らかだが、α = 0 は LS なしの cross entropy となる。

yoheikikuta commented 4 years ago

次に implicit calibration の効果を見る。

まず、calibration が意味するところを明らかにしておく。 自分はこの言葉遣いを見るのは初めてだったが、話としては単純。 出力の confidence を 15 bin に区切り、bin 毎に accuracy を計算してプロットして、その相関が高いほどよく calibrate されているというようだ。つまり confidence が高いほど確かに当たっているという状況がよりよく calibrate されていることになる。

ということで LS の効果を見ていきたいが、以降では LS 以外にも Temperature Scaling (TS) というものも使う。これは softmax を計算するときに温度を入れるもので logit -> logit / T にするもの。T > 1 にすると指数の値は小さくなるので、よりはっきり自信があるものを予測する方向に学習が進むということ。

画像分析に使ってみた結果は以下。 確かに LS を入れることで calibrate されるようになるが、TS の方が効果が大きい。まあ implicit なものなのでこんなもんで、実はこういう効果があるのも分かったんだぞ、という感じ。

次に機械翻訳に使ってみた結果。 モデルの出力は beam-search に使われる、というものになっている。beam-search は良さそうな候補を絞って探していくものなので、よく calibrate されている方が適切な候補に絞れるので良くなるだろうという期待があるためだ。 実際にそうなっていて、TS と同じくらいもしくはそれ以上の効果が出ていることが見て取れる。

両方合わせると、既に calibrate されているものに温度の効果も重ねることで性能が落ちてしまうことも示されている(図は割愛)。

yoheikikuta commented 4 years ago

ここまでは LS が良い効果を及ぼすものに関して言及してきた。 以降の話は蒸留において悪影響を及ぼし得るというもの。

セットアップは CIFAR-10 のデータで teacher: ResNet-56, student: AlexNet というもの。

蒸留では cross entropy を $ (1 - β) H(y, p) + β H (p^t (t), p(T)) $ として、$ β $ はバランスを取るためのパラメタで、$ p(T), p^t(T) $ はそれぞれ TS した後の student, teacher の出力とする。 論文の実験では $ β = 1 $ として teacher output と true label の cross entropy は無視して、teacher と student の関係性のみに注目する。LS によって teacher network は汎化性能を高める方向(少なくとも悪化はしない)に進むものなのでこれは reasonable だろう。

知りたいことは、例えば teacher を LS で学習して、student はその出力を使って普通に学習したら良くなるのか?ということだ。 そうすると student の学習時にラベルに付加されている LS は teacher の出力に依るので、α のような単純な parametrization にはなってない。そこで以下のように定義をしてこれを student から見たときの LS の指標とする。

以上を踏まえて、

である。

結果は以下。 (1), (2) は少し上がる部分もあり、これは単純に LS を使って学習すると汎化性能が高まり得るということを意味している。 一方で (3), (4) は、そもそも (2) よりも明白に悪くなってしまっている。これは蒸留で teacher network に LS や TS を使うことで student network の性能が悪くなってしまうということ。性能向上を見込んで LS 使ってしまうと逆に悪化させてしまうということで注意が必要である。teacher の性能が良ければ良いほど student が賢くなるということではない!これはなかなか示唆的だ。 特徴量空間のプロットを思い出してみれば、ある1つのクラスに属するサンプルはどれもギュッとまとまって同じようなところに map されるので、それを使って学習する student は同じようなデータを使ってしか学習できなくなってしまうという感じだ。

この結果はかなり明白なので、何かちゃんとした理由がつけられそうに感じる。 論文では、それは logit 間の相対的な情報量が落ちてしまったためである、ということで以降にその説明が続いていく。

yoheikikuta commented 4 years ago

logit 間の相対的な情報量が落ちて、の部分だが色々仮定を入れて以下のようなものを計算する。

相互情報量は以下のように計算できる。

ごちゃごちゃ書いたが、もしあるクラスのデータの penultimate 出力がほとんどそのクラスの平均のところに集まるならば、それはクラスを given にすれば penultimate 出力がほぼ分かるので相対エントロピーが 0 に近いということ。 強めの仮定を入れて話を単純にしたので、その辺を認めれば解釈は簡単だ。

これを計算した結果が以下。 penultimate 出力の 2D プロットからほぼ自明に想像できる通り、LS を使うと相互情報量が落ちている。そして上で述べたように同じような teacher の教えでしか学習できなければ student は賢くならないということになっているという話に繋がっているわけである。

正直この結果はこうなることが分かっているものを計算した感はあるが、情報論的な話と結びつけたという点はなかなか面白く、蒸留などもそういう方向性の研究が進められていくかもしれない。論文でも information bottleneck の話を出したりして情報論的にも意義がある仕事だぞとアピールしている。

yoheikikuta commented 4 years ago

ということで一通り読んでみた。 ちょっとインパクトは弱い感じがしないでもないが、単発というより今後にも繋がっていく可能性のある論文という意味ではそれなりには面白かった。

個人的には LS を汎化性能向上で使ってもそんなに効かないんかなぁという印象を受けた(自分では使ったことないので)。 それよりはクラス毎に集まるような状況を利用して abstention mechanism とかと組み合わせた方が面白そうな気はする。