shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

TinyGAN: Distilling BigGAN for Conditional Image Generation #216

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2009.13829] TinyGAN: Distilling BigGAN for Conditional Image Generation

著者・所属機関

Ting-Yun Chang, Chi-Jen Lu

投稿日時(YYYY-MM-DD)

2020-09-29

1. どんなもの?

2. 先行研究と比べてどこがすごいの?

GANはImageNetなどの大規模であり複雑なデータセットで学習を行うことは難しい.BigGANなどの非常に巨大なモデルに大きなバッチサイズを与えて学習を行うことで,ImageNetでも画像生成が成功することが実験的にわかっている.

軽量なモデルでも画像生成を行うために,BigGANなどのモデルを軽量化することが考えられる.分類タスクでは,モデルの軽量化のために蒸留・枝刈り・量子化などの手法が存在している.

本研究ではGANのモデルを蒸留によって軽量化することを考えている.

具体的には学習済みのBigGANを,潜在変数(とクラス情報)と出力画像のペアを生成するBlack-Boxと考え,生徒モデルはこの入出力のペアを真似るように学習を勧めていく.

image

この手法の利点は以下の3点である.

  1. モデルのパラメータや内部にアクセスする必要がない
  2. 学習時には事前に入出力のペアを取得しておくことで,学習済みモデル自体を使用しなくてもいい
  3. 入出力ペアさえあればいいので,どのようなモデル構造でも使用可能である

また学習を行う際には,教師モデルの出力と生徒モデルの出力に対してピクセル単位でのL1損失関数,教師モデルと生徒モデルの分布を識別する敵対的損失関数,モデルの特徴量ごとに計算を行う特徴量単位での損失関数を採用している.

image

3. 技術や手法の"キモ"はどこにある?

3.1 BigGAN Distillation

Pixel-Level Distillation Loss

教師モデルの入出力ペアを学習させる単純な方法は,教師モデルと生徒モデルの出力画像に対してピクセル単位でのL1損失関数を計算することである.

image

この際に教師モデル(BigGAN)のパラメータは固定を行い,潜在変数はBigGANの推論時に使用されている切断正規分布を使用する.

しかし,L1損失関数のみを使用して学習を行った場合,以下のように生成される画像はぼやけてしまう.

image

Adversarial Distillation Loss

生成される画像をより本物にちかづける(Shapeにする)ために,生成された画像が生徒モデルのものか教師モデルのものか識別するDiscriminatorを導入する.

生徒モデル
image
教師モデル
image

なおDiscriminatorには,実験中最も性能が良かったヒンジ損失関数を使ったProjection Discriminatorを採用している.

Feature-Level Distillation Loss

ぼやけた画像が生成されないようにするために,特徴量レベルでの損失関数も提案している.

著者らは生成画像のSourceを識別するDiscriminatorは,画像に関する重要な特徴量を獲得できているはずだという仮定を設けている.

つまりBigGANが生成した画像から得られるDiscriminatorの特徴量と,TinyGANが生成した画像から得られたDiscriminatorの特徴量は似たものになるはずである.この前提をもとに以下のような特徴量レベルでの損失関数を提案している.

image

なおより深い層から得られる特徴量に,より大きな重みを課している.

ここまでで提案されたすべての損失関数を以下のような図にまとめることができる.

image

3.2 Learning from Real Distribution

BigGANでのモード崩壊を改善するために,ImageNetの実データも使用して学習を進めていく.

image

3.3 Network Architecture

Generator

いくつかのGeneratorを試した中,クラス条件付きのBatchNormを使用しており,ResNetベースの構造のモデルがより高い性能を発揮することがわかっている.

また計算コストを低減するために,Attention機構やProgressive-Growing構造は採用しておらず,また通常の畳み込み層に関しても,チャンネル数を削減したDepthwise畳み込み層を使用している.

こうした工夫を行うことでBigGANよりも16倍程度の軽量化に成功している.

クラス条件付きのBatchNormは各層で同一のものを使用している. これらの構造をまとめると以下のようになる.

image

Discriminator

Discriminatorに関しても,スペクトル正規化やProjection構造は採用したまま,Resnet構造を採用せず,DCGANのように畳み込み層を単純に重ねたものを使用することで,Discriminatorのサイズを10倍程度軽量化している.

4. どうやって有効だと検証した?

image

image

image

image

image

image

5. 議論はあるか?

shimopino commented 4 years ago

https://github.com/terarachang/ACCV_TinyGAN