keimoriyama / archive-paper_management

0 stars 0 forks source link

FitNets: Hints for Thin Deep Nets #2

Open keimoriyama opened 2 years ago

keimoriyama commented 2 years ago

FitNets: Hints for Thin Deep Nets

[1412.6550] FitNets: Hints for Thin Deep Nets

背景:なぜその問題を解決したいのか

モデルのパフォーマンスを向上させる上で,隠れ層の深さは重要になる.画像分類や物体検知においてSOTAを達成したモデルの隠れ層は深い.

このようなモデルは推論時に時間と計算リソースを大量に必要とするので,実用的ではない.

目的:どういう問題を解決したのか

この問題を解決するために,大きな教師モデルから小さな生徒モデルに知識を移す手法が提案された

この既存手法では教師モデルと生徒モデル隠れ層の数と深さが同じである.(教師モデルはアンサンブルになっているよね!)

これは隠れ層を深くすることによる利点を生かしきれていないといえる.利点は以下の通り

  1. 理論的に:層が深いと表現の幅が大きく増える
  2. 経験的に:パフォーマスが最も良い畳み込みを用いているImageNetのレイヤーの数は19,22

提案:解決に向けたキーアイデアは何か

教師モデルの隠れ層のパラメータをヒントとして生徒モデルの隠れ層のパラメータを調整する(式3)

教師モデルと生徒モデルの隠れ層の次元数は違うので,生徒モデルの隠れ層の次元数を教師モデルに合わせるための回帰関数rを用いる.(畳み込み回帰をすることでこの回帰に必要なパラメータを大きく削減した)

学習時には,HT→KDの順番にパラメータを更新する.HTで隠れ層のパラメータをチューニングして,KDでモデル全体のパラメータを調整する

カリキュラムラーニングの一般系になっていると言えるよ!

結果:結局問題は解決されたのか.新しくわかったことは?

実験

CIFAR-10&CIFAR-100

教師モデル:畳み込み+maxoutレイヤーを2層重ねたモデル

生徒モデル:畳み込み+maxoutレイヤーを17層重ねたモデル(教師モデルの1/3のパラメータ数)

教師モデルの2つ目のレイヤーをヒントとして,生徒モデルの11個目のレイヤーにヒントを与える

SVHNデータセット

教師モデル:畳み込み+maxoutレイヤーを2層重ねたモデル

生徒モデル:畳み込み+maxoutレイヤーを11層+完全結合レイヤ+softmaxレイヤの計13層のモデル

MNIST

学習プロセスの正しさを検証するために行った(誤差逆伝播法とかと比較したい)

教師モデル:畳み込み+maxoutレイヤーを2層重ねたモデル

生徒モデル:畳み込み+maxoutレイヤーを4層重ねたモデル(教師モデルの8%のパラメータ数)

教師モデルの2つ目のレイヤーをヒントとして,生徒モデルの4つ目のレイヤーを学習させた

AFLW

これまではmaxoutレイヤーを使った畳み込みをモデルとして使ってきたけど,違う構造をしたモデルでもうまくいくことを試したい

教師モデル:畳み込み+ReLUを3層重ねたモデル

生徒モデル1:畳み込み+ReLUを3層重ねたモデル(教師モデルの1/15のパラメータ数)

生徒モデル2:畳み込み+ReLUを3層重ねたモデル(教師モデルの1/2.5のパラメータ数)

考察

たまたまヒントとして与える教師レイヤーの値が優秀だったのでは?

ヒントはあくまで良い出力を得るためのものという雑な説明がある

他の視点から

  1. 1段階目(HT)での学習はネットワークの前半部分を最適化,2段階目(KD)でネットワークの最適化をしている
    1. 1段階目の学習が必ずしも,2段階目の学習に役に立っているとは限らない
      1. ヒントに誘導された隠れ層のレイヤーが入力の特徴を無視してしまう可能性があるため
  2. ヒントとして隠れ層の出力を合計したものを用いてみた

    1. この方法で学習はうまくいかなかった
    2. レイヤー全体にヒントを与えるDSNという手法がある

      1. 表に19層のDSNでのスコアがある
      2. FitNetsよりもスコアがよくない

        →判別に使えるヒントは正則化として強すぎることがわかる

他の学習手法とこの手法の比較

手法は以下の3つ

誤差逆伝播法,KD,HT

モデルの層と畳み込みのチャネル数を変化させてモデルのパフォーマンスを比較した

使うデータセットはCIFAR-10

計算リソースに制限をかけているという認識でOK