shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Contrastive Generative Adversarial Networks #163

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2006.12681] Contrastive Generative Adversarial Networks

著者・所属機関

Minguk Kang, Jaesik Park

投稿日時(YYYY-MM-DD)

2020-06-23

1. どんなもの?

クラスラベルを使用した教師あり設定でのGANでの学習に、SimCLRで導入されたXT-Xent Lossを組み込んだ新たなフレームワークを提案した。これはDiscriminatorから抽出した特徴ベクトルとクラスベクトルのXT-Xent Lossを計算することで、Discmriminatorがよりよい特徴表現を獲得することを目的にしている。

評価指標は改善されているが、SimCLRと同様に大きなバッチサイズと長い学習時間が必要となる。

2. 先行研究と比べてどこがすごいの?

従来の研究でクラスラベルなどの条件付きで学習を行うことで、生成画像の質が大きく向上することが判明しており、多くのGANはcGANの構造を取り入れている。

cGANの構造

image

ACGANの構造

image

本研究では条件付き学習の新たなフレームワークとなるContrastive Generative Adversarial Networks (ContraGAN) を提案している。

これは同一クラスからサンプリングされた画像の潜在表現の相互情報量の下限を最大化することで、同一クラスの画像が近い距離に分布し、異なるクラスの画像が離れた距離にあるような特徴空間を獲得することを目的にしている。

image

3. 技術や手法の"キモ"はどこにある?

3.1 Maximizing Lower bound on Mutual Information

相互情報量とは確率変数XとYは、どの程度の情報を共有しているのかを示す指標である。そのためXとYが互いに独立している場合、相互情報量は0となる。

このことから同一クラスからサンプリングした画像から得られた特徴ベクトル同士の相互情報量は大きくなり、異なるクラスからサンプリングした画像から得れれる特徴ベクトル同士の相互情報量は小さくなることが期待される。

よってこの条件を満たすような特徴ベクトルを抽出可能なEncoderを定義していく。

ここで以下の命題を考える。

変数xとyをそれぞれ、ランダムに選択した画像の特徴ベクトルと画像に対応するクラスとする。次にxとyの同時分布から対応するペアをサンプリングする。 このときに以下が画像のペアの相互情報量の下限とする。

image

この相互情報量の下限の最大化は実装上では、画像から特徴ベクトルを抽出できるEncoderと2つの特徴ベクトルから同じクラスに属しているのか識別できる分類器が存在していることに等しい。

3.2 Conditional Contrastive Loss

ではGANで学習を行う際にDiscrimiantorとGeneratorの目的関数にContrastive Lossを組み込むことを考えていく。

その際にTriplet Lossなどを導入してしまうと学習過程が複雑になってしまい、学習時間が伸びてしまう。抽出した特徴ベクトルを使用した代替タスクを解くことも可能だが、こういった損失関数はデータ間の関係を捉えることができない。

そこでSimCLRで導入されたXT-Text Lossを考える。この損失関数の計算の流れは以下になる。

  1. まず入力されたデータに対してデータ増強を行う。
  2. 変換されたデータをEncoderを使用してk次元の特徴ベクトルを抽出する。
  3. Projectionを使用して、得られた特徴ベクトルを異なる次元dのベクトルとして超球面上に投影する。
  4. 投影されたベクトルを使用してXT-Xent Lossを計算する。

投影まで行うネットワークを l=h(S(x)) とすると、式では以下で表現される。

image

ここまで踏まえた上でどのようにGANに組むのかを考える。本研究ではDiscriminatorを特徴抽出を行う部分と分類を行う部分とに分けることで、以下の図のようにXT-Xent Lossを組み込んでいる。

image

なおそのまま損失関数を組み込むと、適切なデータ増強手法を選択したり、多大な学習時間が必要となってしまうため、データのクラスラベルをベクトルに変換して以下のように組み込んでいる。

image

この損失関数は入力された画像を特徴空間上で、クラスベクトル e(yi) の近くに投影されるように学習を進めていくことができる。

しかしこのままだと同じクラスに属する画像も負例として扱ってしまうため、同じクラスに属する負例のサンプルの同士のコサイン類似度を計算して分子に足し合わせている。

image

これで損失関数の最小化を行うと分子が大きくなる方向、つまり同じクラスのサンプル同士の距離も近づくように学習を行うことができる。

本研究ではこの損失関数をC2 Lossと言及している。

3.3 Contrastive Generative Adversarial Networks

全体のアルゴリズムは以下になる。

image

4. どうやって有効だと検証した?

CIFAR10とTiny ImageNetを使用して各モデルのFIDを比較した。本手法がどのモデルに対しても有効に働いていることがわかる。

image

様々な設定で学習を行った結果、SimCLRと同様に大きなバッチサイズと長い学習を行うことで最も評価指標を改善している。

image

5. 議論はあるか?

shimopino commented 4 years ago

https://github.com/POSTECH-CVLab/PyTorch-StudioGAN