jpfeil closed this issue 1 year ago
@jpfeil nice! ok, so I forgot to handle the greyscale edge case for the perceptual loss.
try 0.1.33?
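Here is a minimal sketch of what handling the greyscale edge case could look like. It assumes the perceptual loss feeds frames into a network pretrained on RGB input (such as VGG) that expects 3 channels. `to_three_channels` is a hypothetical helper, not part of magvit2_pytorch. It just repeats the single channel.

```python
import torch

def to_three_channels(frames: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: repeat a greyscale channel 3 times.

    Expects frames shaped (batch, channels, height, width); inputs whose
    channel count is not 1 are returned unchanged.
    """
    if frames.shape[1] == 1:
        # expand() returns a view, so no pixel data is copied
        return frames.expand(-1, 3, -1, -1)
    return frames

greyscale = torch.randn(10, 1, 32, 32)  # a batch of 1-channel frames
print(to_three_channels(greyscale).shape)  # torch.Size([10, 3, 32, 32])
```

This fix only covers the perceptual-loss path. Any other module built for 3 input channels would still need its own handling.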
Thanks, @lucidrains! I tried 0.1.33, but I got this error:
Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1849, in forward
    fake_logits = self.discr(recon_video_frames)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 641, in forward
    x = block(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 545, in forward
    res = self.conv_res(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
42 RuntimeError: Given groups=1, weight of size [512, 3, 1, 1], expected input[10, 1, 32, 32] to have 3 channels, but got 1 channels instead
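The `RuntimeError` says that a conv layer with weight `[512, 3, 1, 1]` (3 input channels) received a 1-channel batch `[10, 1, 32, 32]`. According to the traceback, this is the discriminator's `conv_res` running on `recon_video_frames`, so the discriminator still appears to be built for RGB input. Below is a plain-PyTorch repro, not the library's code. It shows the mismatch, and it shows that sizing `in_channels` from the data avoids it.

```python
import torch
from torch import nn

frames = torch.randn(10, 1, 32, 32)  # greyscale frames: 1 channel

# A 1x1 conv built for RGB input has weight shape [512, 3, 1, 1]
rgb_conv = nn.Conv2d(in_channels=3, out_channels=512, kernel_size=1)
try:
    rgb_conv(frames)
except RuntimeError as err:
    print("channel mismatch:", err)

# Building the layer from the data's actual channel count avoids the error
grey_conv = nn.Conv2d(in_channels=frames.shape[1], out_channels=512, kernel_size=1)
print(grey_conv(frames).shape)  # torch.Size([10, 512, 32, 32])
```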
@jpfeil try 0.1.35?
@jpfeil you are seeing correct reconstructions without adversarial training, right?
Yeah, I found I needed to increase the codebook size, and with that I can get decent-looking reconstructions without adversarial training. The details aren't there yet, but the general features are encoded.
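Some context on why codebook size matters. Assume the tokenizer uses lookup-free quantization (LFQ), as proposed in the MAGVIT-v2 paper; whether this repo's default matches is an assumption. Under LFQ, each latent dimension is snapped to -1 or +1, and the implicit codebook has 2 ** d entries. A larger codebook therefore means more bits per token. This is a simplified sketch. `lfq_quantize` is a hypothetical name, and the sketch omits the straight-through estimator and entropy penalty that a trainable version needs.

```python
import torch

def lfq_quantize(z: torch.Tensor):
    """Simplified lookup-free quantization (hypothetical helper).

    z has shape (..., d). Each dimension becomes -1 or +1 by its sign,
    and the resulting bit pattern is read as an integer token index.
    """
    bits = z > 0                                   # one bit per latent dimension
    quantized = bits.float() * 2 - 1               # True -> +1.0, False -> -1.0
    powers = 2 ** torch.arange(z.shape[-1])        # 1, 2, 4, ... (int64)
    indices = (bits.long() * powers).sum(dim=-1)   # bit pattern -> integer index
    return quantized, indices

z = torch.tensor([[0.3, -1.2, 0.8], [-0.5, -0.1, -2.0]])
quantized, indices = lfq_quantize(z)
codebook_size = 2 ** z.shape[-1]
print(indices.tolist(), codebook_size)  # [5, 0] 8
```

Under this scheme, adding one latent dimension doubles the codebook.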
nice! thank you!
The discriminator code runs. The discriminator loss converges to zero, but I'll open a separate issue for that.
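Here is one possible reading of a discriminator loss that goes to zero. It assumes a hinge loss, which is common in VQGAN-style setups; whether this repo uses exactly this form is an assumption. With a hinge loss, the value is zero only once every real logit is at least 1 and every fake logit is at most -1. That means the discriminator separates real frames from reconstructions with margin to spare. `hinge_discr_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def hinge_discr_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    # Penalize real logits below +1 and fake logits above -1; zero once both margins hold
    return F.relu(1 - real_logits).mean() + F.relu(1 + fake_logits).mean()

# Confidently separated logits give exactly zero loss
real = torch.tensor([1.5, 2.0, 3.0])
fake = torch.tensor([-1.0, -4.0, -2.5])
print(hinge_discr_loss(real, fake).item())  # 0.0

# Logits inside the margin are still penalized
print(hinge_discr_loss(torch.tensor([0.5]), torch.tensor([0.0])).item())  # 1.5
```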
have you set the adversarial loss weight to be greater than 0?
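This sketch shows why the adversarial loss weight matters, assuming the generator objective is a weighted sum. The function and its `adversarial_loss_weight` argument are hypothetical, not the library's signature. With the weight at 0, the discriminator sends no gradient to the tokenizer, even though the discriminator itself keeps training and can drive its own loss down.

```python
import torch

def generator_loss(recon_loss, perceptual_loss, fake_logits, adversarial_loss_weight):
    """Hypothetical weighted generator objective (not the library's signature)."""
    # Hinge-style generator term: raise the discriminator's logits on reconstructions
    adv_loss = -fake_logits.mean()
    return recon_loss + perceptual_loss + adversarial_loss_weight * adv_loss

# Stand-in for discriminator logits on reconstructed frames
fake_logits = torch.tensor([0.2, -0.4], requires_grad=True)
loss = generator_loss(torch.tensor(0.5), torch.tensor(0.1), fake_logits, adversarial_loss_weight=0.0)
loss.backward()
print(fake_logits.grad)  # zero gradient: no adversarial signal reaches the generator
```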
v0.1.32 works without the GAN, but I get an error when using the GAN again.