lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
MIT License

Reconstruction image is always a solid color #12

Closed: jpfeil closed this issue 10 months ago

jpfeil commented 10 months ago

Hello,

I've been training this on ImageNet, but I'm concerned I'm doing something wrong because the reconstructions are always a solid color. I haven't trained it very long (~1500 steps at batch size 10), but I wanted to check whether this is expected.

Reconstruction at 1300 steps: [image]

Reconstruction at 1200 steps: [image]

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, 
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/imagenet/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 8,
    num_train_steps = 1_000_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=1e-4, # From the paper
    accelerate_kwargs={"split_batches": True, "mixed_precision": 'fp16'},
    random_split_seed=171,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={}
)

trainer.train()

lucidrains commented 10 months ago

@jpfeil could you retry with fp32? and train until 5000 steps? also, grad accum of 4-6 is sufficient (32-64 effective batch size)
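For reference, the effective batch size is the number of samples that contribute to each optimizer step. The sketch below assumes that, with Accelerate's split_batches=True (as in the config above), batch_size is the global batch that gets split across processes, so only grad_accum_every multiplies it. effective_batch_size is a hypothetical helper, not part of magvit2-pytorch.

```python
def effective_batch_size(batch_size: int, grad_accum_every: int,
                         num_processes: int = 1, split_batches: bool = True) -> int:
    """Samples contributing to one optimizer step (hypothetical helper)."""
    # Assumption: with split_batches=True, batch_size is the global batch that
    # Accelerate splits across processes, so the process count does not multiply it.
    per_forward = batch_size if split_batches else batch_size * num_processes
    return per_forward * grad_accum_every


for accum in (4, 6, 8):
    print(f"grad_accum_every={accum}: {effective_batch_size(10, accum)} samples per step")
# grad_accum_every=4: 40 samples per step
# grad_accum_every=6: 60 samples per step
# grad_accum_every=8: 80 samples per step
```

With batch_size = 10, accumulating over 4 to 6 steps gives 40 to 60 samples per update, inside the suggested 32-64 range, versus 80 with the original grad_accum_every = 8.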

lucidrains commented 10 months ago

@jpfeil also share your training curve, try out wandb's report feature for easy sharing

jpfeil commented 10 months ago

Thanks @lucidrains, I'll let you know when the wandb report is ready.

jpfeil commented 10 months ago

@lucidrains This was run on 0.1.24, so I'm going to pull the latest version and retry. The loss was slowly improving, but around step 1000 it became NaN. The only change I made was adding a cosine learning-rate schedule with warmup. I'm also still using bf16, so I'll switch to fp32 in the next run.
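A linear-warmup-then-cosine-decay schedule like the one mentioned above can be sketched as a plain function of the step. The function name, and the warmup and total step counts used below, are assumptions for illustration; they are not magvit2-pytorch trainer options.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr (hypothetical helper)."""
    if step < warmup_steps:
        # Ramp linearly so the final warmup step reaches exactly base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped so steps past total_steps stay at min_lr.
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 999, 1_000, 5_500, 10_000):
    print(step, warmup_cosine_lr(step, 1e-4, warmup_steps=1_000, total_steps=10_000))
```

To drive a PyTorch optimizer with it, lambda step: warmup_cosine_lr(step, 1.0, warmup_steps, total_steps) can be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's initial learning rate by the returned factor.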

https://api.wandb.ai/links/pfeiljx/p2x7x2x2

jpfeil commented 10 months ago

Hi @lucidrains

I ran it using fp32 and trained for 5000 steps, but I did not see any improvement.

https://api.wandb.ai/links/pfeiljx/8kqeyypi

Let me know if you have any suggestions.

jpfeil commented 10 months ago

@lucidrains I ran on the Fashion-MNIST data last night and the model was able to converge:

https://api.wandb.ai/links/pfeiljx/udspvdgu

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d %H:%M:%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, # From the paper
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 5,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 {RUNTIME}'):
    trainer.train()

lucidrains commented 10 months ago

@jpfeil @jacobpfeil i think this repository should support pretraining with 2d conv layers, and then a way to convert it to 3d for video. but let me meditate on the simplest way to achieve this
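One common way to reuse 2D-pretrained weights in a 3D video model is "central" inflation: the 2D kernel is copied into the middle temporal slice of a 3D kernel and the other slices are zeroed, so the inflated layer initially processes each frame independently, exactly like the 2D layer. The sketch below is a generic PyTorch illustration under that assumption, not how magvit2-pytorch implements (or plans to implement) the conversion. inflate_conv2d_to_conv3d is a hypothetical name, and it uses symmetric temporal padding rather than the causal padding a video tokenizer might use.

```python
import torch
import torch.nn as nn


def inflate_conv2d_to_conv3d(conv2d: nn.Conv2d, time_kernel: int = 3) -> nn.Conv3d:
    """Centrally inflate a Conv2d into a Conv3d (hypothetical helper).

    The 2D kernel goes into the middle temporal slice and all other slices are
    zero, so at initialization each output frame equals conv2d(that frame).
    """
    assert time_kernel % 2 == 1, "odd temporal kernel keeps the center slice well defined"
    assert isinstance(conv2d.padding, tuple), "string padding ('same'/'valid') not handled"
    assert conv2d.padding_mode == "zeros", "only zero padding is handled"
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(time_kernel, *conv2d.kernel_size),
        stride=(1, *conv2d.stride),
        padding=(time_kernel // 2, *conv2d.padding),  # symmetric, not causal, in time
        dilation=(1, *conv2d.dilation),
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        conv3d.weight.zero_()
        # weight layout: (out, in // groups, time, height, width)
        conv3d.weight[:, :, time_kernel // 2] = conv2d.weight
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


# Usage: video tensors are (batch, channels, time, height, width).
torch.manual_seed(0)
conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv3d = inflate_conv2d_to_conv3d(conv2d)
video = torch.randn(2, 3, 5, 16, 16)
out = conv3d(video)
print(out.shape)  # torch.Size([2, 8, 5, 16, 16])
```

Because the non-center temporal slices start at zero, training can later learn temporal mixing while the model begins from the 2D behavior.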

jpfeil commented 10 months ago

Thanks @lucidrains. Let me know if I can help run some tests. I have access to a few A100 GPUs.

lucidrains commented 10 months ago

@jpfeil sounds good

let me think about this for a few days or the code will come out wrong

measure twice cut once kinda thing

jpfeil commented 10 months ago

@lucidrains After looking at the Fashion-MNIST results, it looks like the discriminator loss collapsed to zero, so I think learning stopped prematurely. I'm also not getting good reconstructions.

[image: sampled 17]

For VQ-GAN, I've read that the autoencoder should train for a couple of epochs, until it generates good images, before the discriminator is turned on. Is there a way to do that here?
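A delayed discriminator start can be sketched as a step-dependent adversarial weight that stays at zero until a chosen step, similar in spirit to the disc_start option in the original VQGAN (taming-transformers) code. The helpers and the disc_start_step value below are hypothetical and are not magvit2-pytorch options; the 0.1 loss weights mirror the config above.

```python
def adversarial_weight(step: int, base_weight: float, disc_start_step: int) -> float:
    """Zero before disc_start_step, base_weight afterwards (hypothetical helper)."""
    return base_weight if step >= disc_start_step else 0.0


def generator_loss(recon_loss: float, perceptual_loss: float, adv_loss: float, step: int,
                   perceptual_weight: float = 0.1, adv_weight: float = 0.1,
                   disc_start_step: int = 10_000) -> float:
    """Autoencoder loss with the adversarial term switched on late.

    disc_start_step is an assumed value; the discriminator would typically
    also skip its own updates until the same step.
    """
    return (recon_loss
            + perceptual_weight * perceptual_loss
            + adversarial_weight(step, adv_weight, disc_start_step) * adv_loss)


print(generator_loss(1.0, 1.0, 5.0, step=500))     # adversarial term ignored
print(generator_loss(1.0, 1.0, 5.0, step=20_000))  # adversarial term active
```

The same arithmetic works unchanged if the losses are PyTorch tensors, since only multiplication and addition are involved.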

lucidrains commented 10 months ago

@jpfeil yea i could add that, but only if need be

what happens if you set adversarial_loss_weight to 0?

it really should converge for fashion mnist quite quickly, even without the GAN system

jpfeil commented 10 months ago

I get an assertion error because the self.has_gan attribute is set to False. Is it okay to override that assertion?

lucidrains commented 10 months ago

@jpfeil could you point to the line number?

could you also give 0.1.29 a quick try? may be a bug but not entirely sure

lucidrains commented 10 months ago

@jpfeil oh nvm, yes i see it. we should be able to turn off adversarial loss, let me fix

lucidrains commented 10 months ago

@jpfeil try 0.1.31 with use_gan = False on the VideoTokenizer

jpfeil commented 10 months ago

Whoops, my tokenizer change wasn't saved. Running now...

lucidrains commented 10 months ago

@jpfeil give the imagenet run another try

there may have been a bug with how I zeroed the gradients a few patches ago
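As general background on zeroing gradients with accumulation (not a description of the specific bug, which the thread does not detail): gradients should be zeroed once per optimizer step, not once per micro-batch, because zeroing inside the accumulation loop discards all but the last micro-batch. A generic PyTorch sketch of the ordering, not the trainer's actual code; accumulated_step is a hypothetical helper.

```python
import torch
import torch.nn as nn


def accumulated_step(model, optimizer, micro_batches, loss_fn):
    """One optimizer step over several micro-batches (hypothetical helper)."""
    optimizer.zero_grad()  # zero ONCE per optimizer step, before accumulating
    n = len(micro_batches)
    for x, y in micro_batches:
        # Dividing by n makes the summed gradients match the gradient of the
        # mean loss over all micro-batches (when they are equally sized).
        loss = loss_fn(model(x), y) / n
        loss.backward()  # .grad accumulates across backward() calls
        # Calling optimizer.zero_grad() here instead would keep only the
        # last micro-batch's gradient.
    optimizer.step()


# Usage: two micro-batches of 4 should give the same gradient as one batch of 8.
torch.manual_seed(0)
model = nn.Linear(3, 1)
reference = nn.Linear(3, 1)
reference.load_state_dict(model.state_dict())  # same starting weights
x, y = torch.randn(8, 3), torch.randn(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulated_step(model, optimizer, [(x[:4], y[:4]), (x[4:], y[4:])], nn.MSELoss())
nn.MSELoss()(reference(x), y).backward()  # one full batch of 8
print(torch.allclose(model.weight.grad, reference.weight.grad, atol=1e-6))  # True
```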

jpfeil commented 10 months ago

This is resolved for Fashion-MNIST, but I haven't trained on enough ImageNet data yet to tell whether it works there. I'm going to close this now, and if the problem comes up again on ImageNet, I'll open a new issue.

coolbunnyx commented 9 months ago


Hi @jpfeil, do you mind sharing how you ended up solving it? I ran into the same issue: https://github.com/lucidrains/magvit2-pytorch/issues/25

jpfeil commented 8 months ago

Hi @coolbunnyx,

Sorry for the delay. I think you've already solved it, but I was able to get good reconstructions by training for longer.