Open keimoriyama opened 2 years ago
author
link
Correlation Congruence for Knowledge Distillation
重いニューラルネットワーク(教師モデル)の性能をパラメータの小さいネットワーク(生徒モデル)で近似する手法に知識蒸留がある。
既存の知識蒸留の問題点として、教師モデルの特徴空間が捉えられていないという問題がある。 なので、データセット内の特徴やクラス間の特徴を使った蒸留手法を提案した。
手法は以下のような感じで実装されている。 CCKDは2つから構成されていて、インスタンス間の損失とインスタンス間の相関の損失を計算している。
$\mathcal{f}_s$を生徒モデルの特徴表現、$\mathcal{f}_t$を教師モデルの特徴表現とする。 $F_t$と$F_s$を教師と生徒の特徴表現をまとめた行列とする。
ここから相関行列$C$を計算する関数$\psi$を次のように定義する。
$C$の$(i, j)$版目の要素は式5で計算する。 $varphi$は相関メトリック(?)で、なんらかの計算をする関数だと思う。
これまでの計算を用いて、correlation congruence損失を式6で計算する。 インスタンス間の相関関係を蒸留していると感じる
全体の損失は式7の通り。 $\alpha$はハイパラ、$\mathcal{L}{CE}$はcross entropy損失、$\mathcal{L}{KD}$はHintonらの知識蒸留の損失。
インスタンス間の相関を計算するために、カーネルトリックを使った手法を提案する。 $x,y \in \Omega$を2つのインスタンスとして、$k: \Omega \times \Omega \rightarrow R$をマッピング関数とすると3つ提案する。 MMDは埋め込みベクトル間の距離を指標に使っている。 Bilinear Poolは内積で計算する。
Gaussian RBFはP次テイラー展開を使って次のように計算する。
適切なバッチサンプリング関数はクラス内、クラス間のcorrelation congruenceのバランスをつるのに重要になる。 これはミニバッチ内のインスタンスで相関が計算されているから。
2つのサンプリング手法を提案していて、class-uniform random sampler(CUR-sampler)とsuperclass-uniform random sampler(SUR-sampler)がある。
class-uniform random sampler
superclass-uniform random sampler
CUR-samplerではそれぞれのクラスのデータをkサンプルずつ取得する。 SUR-samplerではsuperclassという概念を導入する。superclassはクラスタリングによって得られる。 最初に、学習データから教師モデルを使って特徴を抽出する。 その後に、K-meansを使ってクラスタリングをして、superclassを決定する。
複数のタスクにおいて実験を行い、他の蒸留手法と比較した。
使うデータセットCIFAR-100とImageNet-1K。
Table1はCIFAR-100での実験結果。 既存手法よりも良い精度を達成していることがわかる。
Table2はImageNet-1Kでの実験結果。 既存手法よりも生徒モデルの精度が良くなっている。
MSMT17データセットとMegaFaceを使ってパフォーマンスを比較する。 MSMT17データセットはReID(Person Re-Identification)、MegaFaceは顔認識タスクをやっている。
MSMT17データセット、MegaFace共に既存手法よりも良い精度を達成した。
<img width="642" alt="Screen Shot 2022-06-28 at 11 23 24" src="https://user-images.githubusercontent.com/37134200/176077902-86741c74-00c1-4eda-85f6-cf18b73bc6cb.png"
相関を測る関数を変えて、モデルの精度がどれくらい変化するかを比較した。 関数は、MMD、Bilinear, Gaussian RBFの3種類。 結果はGaussian RBFが最も精度が良かった。
Gaussian RBFにおいてテイラー展開で近似する際の、項数を変化させてパフォーマンスを比較した。 結果は、項数を増やすと精度がその分良くなることがわかった。
Correlation Congruence for Knowledge Distillation
author
link
Correlation Congruence for Knowledge Distillation
背景:なぜその問題を解決したいのか
重いニューラルネットワーク(教師モデル)の性能をパラメータの小さいネットワーク(生徒モデル)で近似する手法に知識蒸留がある。
目的:どういう問題を解決したのか
既存の知識蒸留の問題点として、教師モデルの特徴空間が捉えられていないという問題がある。 なので、データセット内の特徴やクラス間の特徴を使った蒸留手法を提案した。
提案:解決に向けたキーアイデアは何か
提案手法
手法は以下のような感じで実装されている。 CCKDは2つから構成されていて、インスタンス間の損失とインスタンス間の相関の損失を計算している。
$\mathcal{f}_s$を生徒モデルの特徴表現、$\mathcal{f}_t$を教師モデルの特徴表現とする。 $F_t$と$F_s$を教師と生徒の特徴表現をまとめた行列とする。
ここから相関行列$C$を計算する関数$\psi$を次のように定義する。
$C$の$(i, j)$版目の要素は式5で計算する。 $varphi$は相関メトリック(?)で、なんらかの計算をする関数だと思う。
これまでの計算を用いて、correlation congruence損失を式6で計算する。 インスタンス間の相関関係を蒸留していると感じる
全体の損失は式7の通り。 $\alpha$はハイパラ、$\mathcal{L}{CE}$はcross entropy損失、$\mathcal{L}{KD}$はHintonらの知識蒸留の損失。
一般化したカーネルベースの相関
インスタンス間の相関を計算するために、カーネルトリックを使った手法を提案する。 $x,y \in \Omega$を2つのインスタンスとして、$k: \Omega \times \Omega \rightarrow R$をマッピング関数とすると3つ提案する。 MMDは埋め込みベクトル間の距離を指標に使っている。 Bilinear Poolは内積で計算する。
Gaussian RBFはP次テイラー展開を使って次のように計算する。
サンプリングの設計
適切なバッチサンプリング関数はクラス内、クラス間のcorrelation congruenceのバランスをつるのに重要になる。 これはミニバッチ内のインスタンスで相関が計算されているから。
2つのサンプリング手法を提案していて、
class-uniform random sampler
(CUR-sampler)とsuperclass-uniform random sampler
(SUR-sampler)がある。CUR-samplerではそれぞれのクラスのデータをkサンプルずつ取得する。 SUR-samplerではsuperclassという概念を導入する。superclassはクラスタリングによって得られる。 最初に、学習データから教師モデルを使って特徴を抽出する。 その後に、K-meansを使ってクラスタリングをして、superclassを決定する。
結果:結局問題は解決されたのか.新しくわかったことは?
実験
複数のタスクにおいて実験を行い、他の蒸留手法と比較した。
画像認識
使うデータセットCIFAR-100とImageNet-1K。
Table1はCIFAR-100での実験結果。 既存手法よりも良い精度を達成していることがわかる。
Table2はImageNet-1Kでの実験結果。 既存手法よりも生徒モデルの精度が良くなっている。
メトリックラーニング
MSMT17データセットとMegaFaceを使ってパフォーマンスを比較する。 MSMT17データセットはReID(Person Re-Identification)、MegaFaceは顔認識タスクをやっている。
MSMT17データセット、MegaFace共に既存手法よりも良い精度を達成した。
<img width="642" alt="Screen Shot 2022-06-28 at 11 23 24" src="https://user-images.githubusercontent.com/37134200/176077902-86741c74-00c1-4eda-85f6-cf18b73bc6cb.png"
アブレーションスタディ
相関の指標
相関を測る関数を変えて、モデルの精度がどれくらい変化するかを比較した。 関数は、MMD、Bilinear, Gaussian RBFの3種類。 結果はGaussian RBFが最も精度が良かった。
Gaussian RBFにおけるテイラー展開の項数
Gaussian RBFにおいてテイラー展開で近似する際の、項数を変化させてパフォーマンスを比較した。 結果は、項数を増やすと精度がその分良くなることがわかった。