shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Taming GANs with Lookahead #168

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2006.14567] Taming GANs with Lookahead

著者・所属機関

Tatjana Chavdarova, Matteo Pagliardini, Martin Jaggi, Francois Fleuret

投稿日時(YYYY-MM-DD)

2020-06-25

1. どんなもの?

勾配法を数回行った後で、パラメータを更新されたパラメータを線形補間した値に更新するLookaheadをGANに適用した研究。

本手法により、大きなミニバッチを使用することなく大きな分散に対処することができ、GANの性能を向上させることができると考えられる。

2. 先行研究と比べてどこがすごいの?

GANを勾配法で最適化させる場合、GeneratorとDiscriminatorを同じ目的関数に対してそれぞれを最適化させていき、最終的にナッシュ均衡に達することが理想となる。

課題としてはGANを学習させることは非常に難しく、モード崩壊や学習の発散、ハイパラの変更により性能が大きく変化してしまうことなどが挙げられる。そのためBigGANのようなSOTAなモデルを作成するためには、膨大な計算資源が必要とされているのが現状である。

3. 技術や手法の"キモ"はどこにある?

3.1 Algorithm

GANにLookaheadを適用する場合には、実装上はGeneratorとDiscriminatorにそれぞれLookaheadを適用して更新を行う。計算の流れとしては以下になる。

image

重要な点としてはslow weightsを更新する際には、GeneratorとDiscriminatorで同時に更新を行っていることである。

3.2 Motivating example: the bilinear game

4. どうやって有効だと検証した?

モデルの評価のためにMNISTやCIFAR10などの軽量なデータセットを使用してFIDとISの比較を行っている。この際に使用する最適化手法として、Adamと指数移動平均による更新規則を採用している。

image

以下ではMNISTで100万回Iterationさせた際の、GeneratorとDiscriminatorのへシアン行列の固有値と、JVFのヤコビアンの固有値である。図(c)を見るとわかるように、LA-GANでは学習が収束した後に、収束点付近で振動していないことがわかる。

image

次にCIFAR10を使用してBigGANとの比較を行った。興味深い点としてはクラスラベルを使用した教師あり学習の場合、BigGANを越えるFIDを達成していることである。

これはモデルサイズがBigGANが158.3MでありLA-GANが5.1Mであることを考慮すると驚異的であり、バッチサイズもBigGANの2048ではなく、128と現実的な値となっている。

image

学習の安定性を見てみてもLookaheadを使用することで発散することなく学習が進んでいることが分かる。

image

5. 議論はあるか?

shimopino commented 4 years ago

理論的な背景を掴むための論文

shimopino commented 4 years ago

もとのLookaheadのissueでまとめているコードの場合は、挙動に注意する必要がある。

以下のように元のLookaheadを使用すると、Joint Learningではなく論文のAppendix中にあるAlternate Learningになってしまう。

optG = Adam(netG.parameters(), lr=1e-3, betas=(0.9, 0.999))
optD = Adam(netD.parameters(), lr=1e-3, betas=(0.9, 0.999))
lookaheadG = Lookahead(optG, k=5, alpha=0.5)
lookaheadD = Lookahead(optD, k=5, alpha=0.5)

while  iteration < global_iteration:
    # discriminator
    lookaheadD.zero_grad()
    loss_D(netD(input)).backward()
    lookaheadD.step()

    # discriminator
    lookaheadG.zero_grad()
    loss_G(netD(input)).backward()
    lookaheadG.step()
shimopino commented 4 years ago

そこでJoint Learningを行うためにslow weightsfast weightsを更新するAPIは分けておくのが無難である。

optG = Adam(netG.parameters(), lr=1e-3, betas=(0.9, 0.999))
optD = Adam(netD.parameters(), lr=1e-3, betas=(0.9, 0.999))
lookaheadG = Lookahead(optG, k=5, alpha=0.5)
lookaheadD = Lookahead(optD, k=5, alpha=0.5)

while  iteration < global_iteration:
    # discriminator
    lookaheadD.zero_grad()
    loss_D(netD(input)).backward()
    lookaheadD.step(mode="fast")

    # discriminator
    lookaheadG.zero_grad()
    loss_G(netD(input)).backward()
    lookaheadG.step(mode="fast")

    # update slow weights if iteration % k == 0
    lookaheadD.step(mode="slow")
    lookaheadG.step(mode="slow")

あとはstep()内の処理としては、mode="fast"の場合にカウンタをインクリメントし、mode="slow"時にk回学習した場合にのみ、線形補間とfast weightsの更新を実行すればいい

shimopino commented 4 years ago
def step(self, closure=None, mode="fast"):
    # update fast weights
    if mode == "fast":
        loss = self.optimizer.step(closure)
        # increment update counter
        for group in self.param_groups:
            group["counter"] += 1

    # update slow weights
    elif mode == "slow":
        for group in self.param_groups:
            # update when iteration & k == 0
            if group["counter"] >= self.k:
                self.update(group)
                # reset update counter
                group["counter"] = 0