lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
MIT License

Discriminator loss converges to zero early in training #16

Open · jpfeil opened this issue 9 months ago

jpfeil commented 9 months ago

I compared v0.1.26 without the GAN against v0.1.36 with the GAN on the Fashion-MNIST data and got better reconstructions without the GAN: https://api.wandb.ai/links/pfeiljx/f7wdueh0
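A discriminator loss that settles at exactly zero is what a hinge objective produces once the discriminator separates real and fake logits by the margin. The sketch below assumes the hinge discriminator loss common in VQ-GAN-style training (up to a constant factor); whether this run uses it is an assumption, and `hinge_discr_loss` is a hypothetical name, not the library's function.

```python
def hinge_discr_loss(real_logits, fake_logits):
    """Hinge loss for the discriminator: mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))."""
    relu = lambda v: max(0.0, v)
    real_term = sum(relu(1.0 - r) for r in real_logits) / len(real_logits)
    fake_term = sum(relu(1.0 + f) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term


# Once every real logit is >= 1 and every fake logit is <= -1, both terms are zero,
# so these samples give the discriminator no further gradient.
print(hinge_discr_loss([2.0, 1.5], [-1.2, -3.0]))  # 0.0
print(hinge_discr_loss([0.5], [0.0]))              # 1.5
```

A loss pinned at zero therefore means the discriminator wins by a full margin on every sample, which is the collapse the issue title describes.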

Do you have any suggestions for improving training?

I'm using a cosine scheduler for the model and discriminator. Should I use a different learning rate schedule for the discriminator?
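Giving the generator and discriminator different schedules comes down to separate optimizers, each with its own scheduler. A minimal sketch in plain PyTorch, assuming cosine decay for the generator and a constant rate for the discriminator; the modules and losses are placeholders, and whether `VideoTokenizerTrainer` exposes separate schedules is not stated in this thread. The `betas` mirror the config above.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR


def make_optimizers(generator, discriminator, lr=2e-5, num_steps=5_000, betas=(0.9, 0.99)):
    """Separate optimizers/schedules: cosine decay for G, constant LR for D."""
    gen_opt = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
    gen_sched = CosineAnnealingLR(gen_opt, T_max=num_steps)
    # A multiplier of 1.0 at every step keeps the discriminator LR fixed.
    disc_sched = LambdaLR(disc_opt, lr_lambda=lambda step: 1.0)
    return gen_opt, disc_opt, gen_sched, disc_sched


# Hypothetical stand-ins for the tokenizer (generator) and its discriminator.
generator = nn.Linear(8, 8)
discriminator = nn.Linear(8, 1)
gen_opt, disc_opt, gen_sched, disc_sched = make_optimizers(generator, discriminator)

for step in range(100):
    x = torch.randn(4, 8)
    # Placeholder losses; the real ones would be reconstruction/adversarial terms.
    gen_loss = generator(x).pow(2).mean()
    disc_loss = discriminator(x).pow(2).mean()

    gen_opt.zero_grad()
    gen_loss.backward()
    gen_opt.step()

    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()

    # Schedulers step after their optimizers, once per training step.
    gen_sched.step()
    disc_sched.step()

print(gen_opt.param_groups[0]["lr"], disc_opt.param_groups[0]["lr"])
```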

I saw a similar discriminator collapse with VQ-GAN, and I read that delaying the discriminator until the generator is well optimized may help. Maybe we could delay the discriminator until a certain reconstruction loss is reached?
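Delaying the discriminator until a reconstruction target is hit can be expressed as a small latch on a smoothed loss. A minimal sketch: `ReconThresholdGate`, its `threshold`, and its `ema_decay` are hypothetical and illustrative, not settings from magvit2-pytorch.

```python
class ReconThresholdGate:
    """Enable the adversarial loss once a smoothed reconstruction loss drops below a threshold.

    Hypothetical helper: `threshold` and `ema_decay` are illustrative values only.
    """

    def __init__(self, threshold=0.02, ema_decay=0.9):
        self.threshold = threshold
        self.ema_decay = ema_decay
        self.ema = None
        self.enabled = False

    def update(self, recon_loss):
        # An exponential moving average smooths out noisy per-batch losses.
        if self.ema is None:
            self.ema = recon_loss
        else:
            self.ema = self.ema_decay * self.ema + (1 - self.ema_decay) * recon_loss
        # Latch: once enabled, the discriminator stays on even if the loss bounces back up.
        if self.ema < self.threshold:
            self.enabled = True
        return self.enabled

    def adversarial_weight(self, base_weight):
        # Zero out the adversarial term until the gate opens.
        return base_weight if self.enabled else 0.0


gate = ReconThresholdGate(threshold=0.05, ema_decay=0.5)
for loss in [0.2, 0.1, 0.04, 0.01, 0.01]:
    print(gate.update(loss), gate.adversarial_weight(0.1))
```

The latch is a design choice: without it, the discriminator would switch on and off whenever the smoothed loss hovers around the threshold.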

After googling some strategies, I came across the unrolled GAN, in which the generator is trained against a copy of the discriminator unrolled a few optimization steps ahead. I'm not sure how difficult it would be to implement a similar strategy here.
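In an unrolled GAN (Metz et al., "Unrolled Generative Adversarial Networks"), the generator's loss is computed against a differentiable copy of the discriminator advanced k optimization steps, so the generator anticipates how the discriminator will respond. A minimal sketch with toy MLPs, assuming the logistic GAN loss with the non-saturating generator variant and plain SGD for the inner steps (the paper's exact objective differs); `unrolled_generator_loss` is a hypothetical name, and nothing here is part of magvit2-pytorch.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call


def unrolled_generator_loss(generator, discriminator, real, z, unroll_steps=3, inner_lr=1e-2):
    """Generator loss against a discriminator unrolled `unroll_steps` SGD steps ahead.

    The discriminator's real parameters are never modified; only a functional copy is.
    """
    fake = generator(z)
    # Start from the discriminator's current parameters (still attached to autograd).
    params = dict(discriminator.named_parameters())
    for _ in range(unroll_steps):
        d_real = functional_call(discriminator, params, (real,))
        d_fake = functional_call(discriminator, params, (fake,))
        d_loss = F.softplus(-d_real).mean() + F.softplus(d_fake).mean()
        # create_graph=True lets the generator gradient flow through the unrolled updates.
        grads = torch.autograd.grad(d_loss, tuple(params.values()), create_graph=True)
        params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}
    # Score the fakes with the "future" discriminator (non-saturating generator loss).
    d_fake = functional_call(discriminator, params, (fake,))
    return F.softplus(-d_fake).mean()


torch.manual_seed(0)
G = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
D = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
g_opt = torch.optim.Adam(G.parameters(), lr=1e-3)

loss = unrolled_generator_loss(G, D, real=torch.randn(32, 2), z=torch.randn(32, 4))
g_opt.zero_grad()
loss.backward()  # also writes .grad on D's parameters; zero them before the D step
g_opt.step()
```

For a video tokenizer's discriminator, the second-order gradients through every unrolled step would cost far more memory and time than in this toy setup.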

I'm just brainstorming, so feel free to address or ignore any of these comments.

```python
import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
```

lucidrains commented 9 months ago

@jpfeil can you screenshot the paper section where they propose delaying the discriminator training? (and link the paper too)

lucidrains commented 9 months ago

@jpfeil do you have adversarial_loss_weight greater than 0? Also try another run where your perceptual_loss_weight is 0.1.
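Both weights scale separate terms of the tokenizer's training objective, which is why an adversarial weight of 0 would leave the generator ignoring the discriminator entirely. A minimal sketch, assuming a plain weighted sum (the real implementation may add further terms, such as quantizer auxiliary losses); `total_generator_loss` is hypothetical and only borrows the two weight names from the comment above.

```python
def total_generator_loss(recon_loss, perceptual_loss, adversarial_loss,
                         perceptual_loss_weight, adversarial_loss_weight):
    """Weighted sum of generator-side losses (hypothetical helper)."""
    return (recon_loss
            + perceptual_loss_weight * perceptual_loss
            + adversarial_loss_weight * adversarial_loss)


# With adversarial_loss_weight = 0 the adversarial term has no effect on the total,
# no matter how large it is.
print(total_generator_loss(0.05, 0.4, 2.0, perceptual_loss_weight=0.1, adversarial_loss_weight=0.0))
print(total_generator_loss(0.05, 0.4, 2.0, perceptual_loss_weight=0.1, adversarial_loss_weight=0.5))
```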

jpfeil commented 9 months ago

Thanks @lucidrains. I'll try again with those parameters. I saw it in the taming implementation here: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/vqperceptual.py#L51
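The linked taming-transformers code gates the adversarial part of the loss with a step threshold: before a given step, the discriminator factor is replaced by 0. A minimal re-implementation of that idea in plain Python; `adopt_weight` mirrors the helper of the same name in taming's vqperceptual.py, while `gated_losses` and its arguments are hypothetical simplifications (taming also computes an adaptive `d_weight` from gradient norms, which is omitted here).

```python
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` instead of `weight` until `global_step` reaches `threshold`."""
    if global_step < threshold:
        weight = value
    return weight


def gated_losses(nll_loss, g_loss, d_loss, global_step, disc_start, disc_factor=1.0, d_weight=1.0):
    # The same factor gates both the generator's adversarial term and the discriminator loss,
    # so neither side trains adversarially before `disc_start`.
    factor = adopt_weight(disc_factor, global_step, threshold=disc_start)
    gen_total = nll_loss + d_weight * factor * g_loss
    disc_total = factor * d_loss
    return gen_total, disc_total


print(gated_losses(1.0, 5.0, 2.0, global_step=10, disc_start=100))   # (1.0, 0.0)
print(gated_losses(1.0, 5.0, 2.0, global_step=100, disc_start=100))  # (6.0, 2.0)
```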

lucidrains commented 9 months ago

@jpfeil welp.. whatever Robin and Patrick do goes; they are the best in the world.

let me add that

lucidrains commented 9 months ago

@jpfeil ok, I added that same functionality here. Try removing the learning rate schedule in your next run too; you shouldn't need it for something this easy.

lucidrains commented 9 months ago

@jpfeil you don't happen to have relatives in Massachusetts, do you?

jpfeil commented 9 months ago

@lucidrains Nice. Let me try it out again. No, I don't have any relatives in Massachusetts. Did you meet someone with the last name Pfeil?

lucidrains commented 9 months ago

yea, I knew someone back in high school with the Pfeil family name. Tragedy struck and they moved away though. You are the second Pfeil I've met!

jpfeil commented 9 months ago

That's amazing. It's not a common name. Sorry to hear about your friend.