Open shimopino opened 4 years ago
理論的な背景を掴むための論文
もとのLookaheadのissueでまとめているコードの場合は、挙動に注意する必要がある。
以下のように元のLookaheadを使用すると、Joint Learning
ではなく論文のAppendix中にあるAlternate Learning
になってしまう。
optG = Adam(netG.parameters(), lr=1e-3, betas=(0.9, 0.999))
optD = Adam(netD.parameters(), lr=1e-3, betas=(0.9, 0.999))
lookaheadG = Lookahead(optG, k=5, alpha=0.5)
lookaheadD = Lookahead(optD, k=5, alpha=0.5)
while iteration < global_iteration:
# discriminator
lookaheadD.zero_grad()
loss_D(netD(input)).backward()
lookaheadD.step()
# discriminator
lookaheadG.zero_grad()
loss_G(netD(input)).backward()
lookaheadG.step()
そこでJoint Learning
を行うためにslow weights
とfast weights
を更新するAPIは分けておくのが無難である。
optG = Adam(netG.parameters(), lr=1e-3, betas=(0.9, 0.999))
optD = Adam(netD.parameters(), lr=1e-3, betas=(0.9, 0.999))
lookaheadG = Lookahead(optG, k=5, alpha=0.5)
lookaheadD = Lookahead(optD, k=5, alpha=0.5)
while iteration < global_iteration:
# discriminator
lookaheadD.zero_grad()
loss_D(netD(input)).backward()
lookaheadD.step(mode="fast")
# discriminator
lookaheadG.zero_grad()
loss_G(netD(input)).backward()
lookaheadG.step(mode="fast")
# update slow weights if iteration % k == 0
lookaheadD.step(mode="slow")
lookaheadG.step(mode="slow")
あとはstep()
内の処理としては、mode="fast"
の場合にカウンタをインクリメントし、mode="slow"
時にk回学習した場合にのみ、線形補間とfast weights
の更新を実行すればいい
def step(self, closure=None, mode="fast"):
# update fast weights
if mode == "fast":
loss = self.optimizer.step(closure)
# increment update counter
for group in self.param_groups:
group["counter"] += 1
# update slow weights
elif mode == "slow":
for group in self.param_groups:
# update when iteration & k == 0
if group["counter"] >= self.k:
self.update(group)
# reset update counter
group["counter"] = 0
論文へのリンク
[arXiv:2006.14567] Taming GANs with Lookahead
著者・所属機関
Tatjana Chavdarova, Matteo Pagliardini, Martin Jaggi, Francois Fleuret
投稿日時(YYYY-MM-DD)
2020-06-25
1. どんなもの?
勾配法を数回行った後で、パラメータを更新されたパラメータを線形補間した値に更新するLookaheadをGANに適用した研究。
本手法により、大きなミニバッチを使用することなく大きな分散に対処することができ、GANの性能を向上させることができると考えられる。
2. 先行研究と比べてどこがすごいの?
GANを勾配法で最適化させる場合、GeneratorとDiscriminatorを同じ目的関数に対してそれぞれを最適化させていき、最終的にナッシュ均衡に達することが理想となる。
課題としてはGANを学習させることは非常に難しく、モード崩壊や学習の発散、ハイパラの変更により性能が大きく変化してしまうことなどが挙げられる。そのためBigGANのようなSOTAなモデルを作成するためには、膨大な計算資源が必要とされているのが現状である。
3. 技術や手法の"キモ"はどこにある?
3.1 Algorithm
GANにLookaheadを適用する場合には、実装上はGeneratorとDiscriminatorにそれぞれLookaheadを適用して更新を行う。計算の流れとしては以下になる。
重要な点としては
slow weights
を更新する際には、GeneratorとDiscriminatorで同時に更新を行っていることである。3.2 Motivating example: the bilinear game
4. どうやって有効だと検証した?
モデルの評価のためにMNISTやCIFAR10などの軽量なデータセットを使用してFIDとISの比較を行っている。この際に使用する最適化手法として、Adamと指数移動平均による更新規則を採用している。
以下ではMNISTで100万回Iterationさせた際の、GeneratorとDiscriminatorのへシアン行列の固有値と、JVFのヤコビアンの固有値である。図(c)を見るとわかるように、LA-GANでは学習が収束した後に、収束点付近で振動していないことがわかる。
次にCIFAR10を使用してBigGANとの比較を行った。興味深い点としてはクラスラベルを使用した教師あり学習の場合、BigGANを越えるFIDを達成していることである。
これはモデルサイズがBigGANが158.3MでありLA-GANが5.1Mであることを考慮すると驚異的であり、バッチサイズもBigGANの2048ではなく、128と現実的な値となっている。
学習の安定性を見てみてもLookaheadを使用することで発散することなく学習が進んでいることが分かる。
5. 議論はあるか?