Oord, A. v. d., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. In Advances in Neural Information Processing Systems, 2017
Liang, K. J., Li, C., Wang, G., and Carin, L. Generative adversarial network training is a continual learning problem. arXiv preprint arXiv:1811.11083, 2018
論文へのリンク
[arXiv:2004.02088] Feature Quantization Improves GAN Training
著者・所属機関
Yang Zhao, Chunyuan Li, Ping Yu, Jianfeng Gao, Changyou Chen
投稿日時(YYYY-MM-DD)
2020-04-05
1. どんなもの?
GANを学習させる際に生じる不安定が、常に一定のラベル分布と変化する生成分布間のバランス差で生じるミニバッチ間の統計量の不一致だとした。
実データと生成データを両方とも同じ離散空間で量子化させることで、直近の分布に対して一定の値を返すKey-Valueの役割を持たせた。
2. 先行研究と比べてどこがすごいの?
GANは高次元空間でのナッシュ均衡を最適化させるため最適化が難しく、SGDによる最適化では学習が不安定になってしまうことが知られている。
著者らはこの不安定性の原因をミニバッチによる学習に起因するものだと指摘している。具体的には以下の3点があげられる。
本研究では実画像と生成画像の移動平均を用いて、Discriminator内の特徴量を離散化させることで、ミニバッチの変化に対して頑強なFQ-GANを提案した。
3. 技術や手法の"キモ"はどこにある?
3.1 From Continuous to Quantized Representations
Discriminatorで行う計算をまとめると、まず画像からD次元の特徴ベクトルをボトルネックのネットワークから抽出し、その後にこのベクトルを入力に画像が本物か偽物かを判定する。
この特徴量を離散化させるためにまずは離散空間を定義する。それぞれD次元のベクトルeをK個用意し特徴量を離散化させる。離散化の際にはボトルネックから抽出した特徴ベクトルを、最も近い離散空間上のベクトルに変換する。
つまり最終的な計算の流れは以下のように中間に特徴ベクトルを離散化させるネットワークを挟む形になる。
図示すると以下のようになる。
3.2 Dictionary Learning
3.2.1 Loss
離散空間上で学習を行うために先行研究[Oord et al., 2017]で提案された2つの損失関数を採用している。
1つ目の損失関数では離散化させたベクトルと離散化前のベクトルのL2損失を計算することで、選ばれた離散ベクトルeが離散化前の特徴ベクトルに近くなるように学習を行う。
2つ目の損失関数ではボトルネックから抽出される特徴量が離散ベクトルに近くなるようにL2損失を計算することで、特徴ベクトルの変化によって離散ベクトルが大きく変動してしまうことを防いでいる。
なお以下の数式で表現されているsgとはstop-gradientを表しており、勾配の伝播が引数のようにまで到達しないようにしている。
3.2.2 A dynamic & consistent dictionary
先行研究[Liang et al., 2018]からDiscriminatorの学習は常に変化することが指摘されており、Generatorから生成される画像の統計量は学習とともに変化するため、Discriminatorはこの変化する分布を学習する必要がある。
そこで離散化させる際にDiscriminatorが学習初期のGANの統計量を学習させないため、離散ベクトルとしてをqueueのデータ構造を採用し、最も直近のミニバッチから得られた離散ベクトルをenqueueし、最も古いミニバッチから得られた離散ベクトルはdequeueするようにしている。
3.2.3 Momentum update of dictionary
損失関数のうちDisctionary Lossを計算し離散ベクトルを更新する際に移動平均を利用する。n個のサンプルのうち、nk個の特徴量がekに離散化する場合に、更新は以下の数式で表現される。
これで離散ベクトルekの学習はSmoothになることが期待される。なおλが0の場合は直近のミニバッチのみを参照することを意味しており、実験ではλ=0.90に設定している。
3.3 FQ-GAN Training
最終的に損失関数は以下の数式で表現される。
また全体のアルゴリズムは以下のように表現される。
本手法の特徴の1つはそのScalabilityであり、ミニバッチの統計量を学習ではなく離散ベクトルの学習に置き換えることで、離散ベクトルの数を変化させることでデータセットのサイズにRobustに対応することができる。
また本手法は実データも生成データも互いに共有している離散空間に投影されるため、暗黙的に特徴量のマッチングを行っていることになり、以下の図のように特ベクトルを特定のセントロイドに投影する。
3.4 FQ-GAN for image generation
画像ドメインに本手法を適用することを考える。通常CNNから抽出できる特徴ベクトルのサイズはCxLxWである。離散化させる際には、画像中の1点を特徴量を使用し、チャンネル数と同じ数の次元数を有する離散ベクトルで置き換える。
4. どうやって有効だと検証した?
本手法をBigGAN・StyleGAN・U-GAT-ITに適用した。
CIFAR100に対して本手法を適用した際に評価指標がどのように変化するのかを検証した。
次にBigGANを用いて、CIFAR10・CIFAR100・ImageNetに本手法を適用した場合の評価指標の比較を行う。どのデータセットに対しても改善効果が表れており、本しゅほが有効に働いていることがわかる。
同じEpoch数を学習させた際の計算コストをまとめる。どのデータに関しても学習コストの増大は無視できる程度であることがわかる。
また本手法を使用したほうがより早く学習が収束していることがわかり、同程度の性能をより短時間で得られることがわかる。
次にImageNetの各カテゴリの画像を生成した際の評価指標を比較する。実画像と生成画像を離散化させたベクトルの分布間距離をMMDで計算した結果と、クラスを指定した生成した画像のFIDを学習済みInceptionモデルで計算することで実画像の分布にどの程度近いのか検証した。
結果を見ると本手法を導入することで、どちらの指標に関して優位な結果を示していることがわかる。
FFHQデータセットを使用してStyleGANに本手法を適用した際の結果を検証する。どの解像度の画像生成に関しても、元のモデルよりも高い性能を発揮していることがわかる。
スタイル変換モデルのSOTAであるU-GAT-ITに対して本手法を適用した場合の結果を比較する。どのデータセットに対しても高い性能を発揮していることがわかる。
ユーザー調査を行い生成結果の比較を行った場合でも本手法が既存モデルをどのデータセットに対しても高い性能を発揮していることがわかる。
生成が追うの比較を行うと以下のようになる。本手法を使用したモデルのほうが、よりTargetのドメインに合う画像を生成できているような気もする。
5. 議論はあるか?
6. 次に読むべき論文はあるか?
論文情報・リンク