lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in PyTorch
MIT License
565 stars, 34 forks

Running with GAN raises RuntimeError #15

Closed: jpfeil closed this issue 1 year ago

jpfeil commented 1 year ago

v0.1.32 works without the GAN, but I get an error again when using the GAN.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "bf16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()

Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1832, in forward
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.grad_layer_wrt_loss) at 0x7fff42659900>", line 50, in grad_layer_wrt_loss
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 129, in grad_layer_wrt_loss
    return torch_grad(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/autograd/__init__.py", line 394, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

lucidrains commented 1 year ago

@jpfeil nice! OK, so I forgot to handle the greyscale edge case for the perceptual loss.

try 0.1.33?
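
The first traceback can be reproduced in isolation. This sketch assumes that, for greyscale input, the perceptual loss ends up as a constant tensor with no connection to the decoder. That is consistent with the "greyscale edge case" above but is not confirmed from the library source. `grad_layer_wrt_loss` here is a hypothetical stand-in named after the helper in the traceback, `safe_grad_norm` is a hypothetical guard, and `last_dec_layer` is a toy conv layer, not the real decoder.

```python
import torch
from torch import nn
from torch.autograd import grad as torch_grad

# Hypothetical stand-in for the helper named in the traceback:
# gradient of a scalar loss with respect to one parameter tensor.
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs=loss,
        inputs=layer,
        grad_outputs=torch.ones_like(loss),
        retain_graph=True,  # keep the graph so other losses can reuse it
    )[0].detach()

# Toy "last decoder layer" producing a 1-channel (greyscale) reconstruction
last_dec_layer = nn.Conv2d(8, 1, kernel_size=3, padding=1)
recon = last_dec_layer(torch.randn(2, 8, 32, 32))
target = torch.randn_like(recon)

# A loss that flows through the layer: the gradient is well defined
recon_loss = nn.functional.mse_loss(recon, target)
recon_grad_norm = grad_layer_wrt_loss(recon_loss, last_dec_layer.weight).norm(p=2)

# Assumed greyscale case: the perceptual loss is a constant, detached from the graph
perceptual_loss = torch.tensor(0.0)
try:
    grad_layer_wrt_loss(perceptual_loss, last_dec_layer.weight)
    raised = False
except RuntimeError:
    # "element 0 of tensors does not require grad and does not have a grad_fn"
    raised = True

# One possible guard (hypothetical): skip the gradient norm for disconnected losses
def safe_grad_norm(loss, layer):
    if not loss.requires_grad:
        return None
    return grad_layer_wrt_loss(loss, layer).norm(p=2)
```

Skipping is only one option; computing the perceptual loss on a channel-repeated image instead would keep any gradient-norm-based weighting active for greyscale data.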

jpfeil commented 1 year ago

Thanks, @lucidrains! I tried 0.1.33, but I got this error:

Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1849, in forward
    fake_logits = self.discr(recon_video_frames)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 641, in forward
    x = block(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 545, in forward
    res = self.conv_res(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [512, 3, 1, 1], expected input[10, 1, 32, 32] to have 3 channels, but got 1 channels instead

lucidrains commented 1 year ago

@jpfeil try 0.1.35?
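
The second error is a plain channel mismatch: a 1x1 conv whose weight is `[512, 3, 1, 1]` (3 input channels) receives a `[10, 1, 32, 32]` greyscale batch. The sketch below reproduces it with a toy `nn.Conv2d` and shows two generic workarounds. These are illustrations only, not the actual change made in 0.1.35, and `to_rgb` is a hypothetical helper.

```python
import torch
from torch import nn

channels = 1
frames = torch.randn(10, channels, 32, 32)  # same shape as the failing input

# Reproduce: a first conv built for RGB (weight [512, 3, 1, 1]) fed greyscale frames
rgb_conv = nn.Conv2d(3, 512, kernel_size=1)
try:
    rgb_conv(frames)
    mismatch = False
except RuntimeError:
    mismatch = True  # "expected input[10, 1, 32, 32] to have 3 channels, but got 1"

# Option A: size the discriminator's input conv from the tokenizer's `channels`
discr_conv = nn.Conv2d(channels, 512, kernel_size=1)
out = discr_conv(frames)

# Option B (hypothetical helper): repeat the grey channel for networks that need RGB
def to_rgb(x):
    return x.repeat(1, 3, 1, 1) if x.shape[1] == 1 else x

rgb_out = rgb_conv(to_rgb(frames))
```

Option A suits a discriminator trained from scratch; option B suits networks with fixed RGB weights, such as a pretrained VGG used for a perceptual loss.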

lucidrains commented 1 year ago

@jpfeil you are seeing correct reconstructions without adversarial training, right?

jpfeil commented 1 year ago

Yeah, I found I needed to increase the codebook size, and now the reconstructions look decent without adversarial training. The fine details aren't there yet, but the general features are encoded.

[Image: reconstruction samples ("sampled 47")]
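
For context on what `codebook_size` controls here: with `use_fsq=False` the tokenizer is assumed to use lookup-free quantization (LFQ) as described in the MagViT2 paper, where a codebook of size 2^d corresponds to d latent dimensions, each snapped to ±1. The sketch below is a simplified illustration of that mapping (no entropy loss, no straight-through estimator), not the library's quantizer; `lfq_quantize` is a hypothetical name.

```python
import torch

codebook_size = 2 ** 13                    # value from the script above
num_bits = codebook_size.bit_length() - 1  # 13 binary dimensions per token

def lfq_quantize(z):
    """Snap each of the last `num_bits` dims to +1/-1 and read the signs as a binary index."""
    positive = z > 0
    codes = torch.where(positive, torch.ones_like(z), -torch.ones_like(z))
    weights = 2 ** torch.arange(num_bits, device=z.device)  # 1, 2, 4, ..., 4096
    indices = (positive.long() * weights).sum(dim=-1)       # integer in [0, codebook_size)
    return codes, indices

# A toy batch of latent grids with one 13-dim vector per spatial position
z = torch.randn(4, 16, 16, num_bits)
codes, indices = lfq_quantize(z)
```

Under this assumption, each doubling of `codebook_size` adds one bit of capacity per token.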

lucidrains commented 1 year ago

nice! thank you!

jpfeil commented 1 year ago

The discriminator code runs now. The discriminator loss converges to zero, but I'll open a separate issue for that.

lucidrains commented 1 year ago

have you set the adversarial loss weight to be greater than 0?
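
To illustrate why that weight matters: in VQGAN-style training the generator's adversarial term is scaled by a weight before it is added to the reconstruction and perceptual losses, so a weight of 0 removes the discriminator's signal from the tokenizer entirely. The sketch below assumes a hinge formulation and an adaptive weight equal to a ratio of gradient norms. The traceback's `norm_grad_wrt_perceptual_loss` suggests something similar, but the exact formula is an assumption, and all function and argument names here are hypothetical.

```python
import torch
import torch.nn.functional as F

def hinge_discr_loss(real_logits, fake_logits):
    # zero once real logits >= 1 and fake logits <= -1 (the discriminator "wins")
    return (F.relu(1 - real_logits) + F.relu(1 + fake_logits)).mean()

def hinge_gen_loss(fake_logits):
    # the tokenizer is rewarded for raising the discriminator's score on reconstructions
    return -fake_logits.mean()

def adaptive_weight(norm_grad_perceptual, norm_grad_gen, eps=1e-5):
    # balance the adversarial gradient against the perceptual one at the last decoder layer
    return norm_grad_perceptual / norm_grad_gen.clamp(min=eps)

def tokenizer_loss(recon_loss, perceptual_loss, gen_loss, adapt_w, adversarial_loss_weight):
    return recon_loss + perceptual_loss + adversarial_loss_weight * adapt_w * gen_loss

# Toy numbers: a discriminator that separates real and fake by a wide margin
real_logits = torch.full((8,), 2.0)
fake_logits = torch.full((8,), -3.0)
d_loss = hinge_discr_loss(real_logits, fake_logits)  # 0.0
g_loss = hinge_gen_loss(fake_logits)                 # 3.0

recon_loss, perceptual_loss = torch.tensor(0.10), torch.tensor(0.20)
w = adaptive_weight(torch.tensor(0.5), torch.tensor(0.25))  # 2.0

loss_without_gan = tokenizer_loss(recon_loss, perceptual_loss, g_loss, w, adversarial_loss_weight=0.0)
loss_with_gan = tokenizer_loss(recon_loss, perceptual_loss, g_loss, w, adversarial_loss_weight=0.1)
```

With a hinge loss, a discriminator loss of exactly zero means real and reconstructed frames are separated by the full margin; if the adversarial weight is 0, nothing pushes the tokenizer to close that gap.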