dome272 / VQGAN-pytorch

Pytorch implementation of VQGAN (Taming Transformers for High-Resolution Image Synthesis) (https://arxiv.org/pdf/2012.09841.pdf)

Missing projection of Attention values in non-local block and incorrect position of the Upsampling block in the decoder #23

Closed hgupta01 closed 1 month ago

hgupta01 commented 1 month ago

Hi,

Thanks for the clean implementation of the VQGAN. It is really helpful for me.

I found one bug in the non-local block (helper.py, line 112): it is missing the projection of the attention output (proj_out) before it is added to the input.

This might be the cause of the weird artifacts in the image.
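For reference, here is a minimal sketch of how the fix could look. The module and attribute names mirror the repo's NonLocalBlock (gn, q, k, v, proj_out); the exact reshapes follow the standard single-head attention formulation used in taming-transformers and are an assumption, not this repo's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonLocalBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gn = nn.GroupNorm(32, channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        h = self.gn(x)
        q, k, v = self.q(h), self.k(h), self.v(h)

        b, c, height, width = q.shape
        q = q.reshape(b, c, height * width).permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, height * width)                   # (b, c, hw)
        v = v.reshape(b, c, height * width)                   # (b, c, hw)

        attn = torch.bmm(q, k) * (c ** -0.5)                  # (b, hw, hw)
        attn = F.softmax(attn, dim=2)

        out = torch.bmm(v, attn.permute(0, 2, 1))             # (b, c, hw)
        out = out.reshape(b, c, height, width)

        # the bug: the current forward adds the attention values to x directly;
        # they must be passed through proj_out before the residual add
        return x + self.proj_out(out)
```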

SnakeOnex commented 1 month ago

Hello,

Have you tried fixing it and seeing if the results of this repo improve? I tried adding the projection, but the results are still garbage.

hgupta01 commented 1 month ago

Hi,

I am currently rewriting the repo for my own work, so I haven't tried it yet. However, there is one more bug in the placement of the Upsampling block.

The Upsampling block needs to be moved from position 28 to position 10 in the printout below. This seems to be a significant bug and might be causing the weird effect: in the current version the last upsampling is done only at the very end, so the artifacts introduced by upsampling may not "dissolve" properly.

Decoder2( (model): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 512, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (2): NonLocalBlock( (gn): GroupNorm(32, 512, eps=1e-06, affine=True) (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) ) (3): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 512, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (4): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 512, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (5): NonLocalBlock( (gn): GroupNorm(32, 512, eps=1e-06, affine=True) (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) ) (6): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 512, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (7): NonLocalBlock( (gn): GroupNorm(32, 512, eps=1e-06, affine=True) (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) ) (8): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 512, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (9): NonLocalBlock( (gn): GroupNorm(32, 512, eps=1e-06, affine=True) (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) ) (10): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 512, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (channel_up): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1)) ) (11): NonLocalBlock( (gn): GroupNorm(32, 256, eps=1e-06, affine=True) (q): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(256, 256, 
kernel_size=(1, 1), stride=(1, 1)) ) (12): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (13): NonLocalBlock( (gn): GroupNorm(32, 256, eps=1e-06, affine=True) (q): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) ) (14): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (15): NonLocalBlock( (gn): GroupNorm(32, 256, eps=1e-06, affine=True) (q): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (k): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (v): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (proj_out): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) ) (16): UpSampleBlock( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (17): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (18): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (19): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 256, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (20): UpSampleBlock( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (21): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 256, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (channel_up): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1)) ) (22): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 128, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (23): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 128, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (24): UpSampleBlock( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (25): ResidualBlock( 
(block): Sequential( (0): GroupNorm(32, 128, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (26): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 128, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (27): ResidualBlock( (block): Sequential( (0): GroupNorm(32, 128, eps=1e-06, affine=True) (1): Swish() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): GroupNorm(32, 128, eps=1e-06, affine=True) (4): Swish() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (28): UpSampleBlock( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (29): GroupNorm(32, 128, eps=1e-06, affine=True) (30): Swish() (31): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) )

SnakeOnex commented 1 month ago

Hello, very cool. I am also trying to rewrite this repo so that it works; I am currently following the approach from this issue.

I replaced the model and the perceptual loss, removed the GAN components for now, and launched three training runs for each version, hoping the contours/artifacts will disappear. However, I think the contours were still there even after replacing the VQGAN model with the code from the original paper, so it might be some other issue. That said, I only just started on this, so I might have made some stupid mistake.

I can try to run your change with the Upsample block and see if things get better.

SnakeOnex commented 1 month ago

Looking at the upsampling issue you mentioned, I don't think I understand. The layers seem to be in the right places, matching the original paper.

image

hgupta01 commented 1 month ago

Hi,

The implementation is slightly different from the paper. According to the paper, there should be M blocks of {Residual Block -> Upsample Block}; however, both this repo and taming-transformers arrange them differently.

In the taming-transformers implementation, the first (M-1) blocks are {Residual Block -> Upsample Block} and the last block (m = M) is only {Residual Block}.

In this repo, the first block is only {Residual Block} and the next (M-1) blocks are {Residual Block -> Upsample Block}. You can verify this by checking the condition at line 26 in decoder.py: the upsampling block is not created for the first block (when i = 0), as shown in the sketch below.
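To make the difference concrete, here is a hedged sketch of the two placements of the UpSampleBlock inside the decoder's resolution loop. The function, its signature, and the variable names are illustrative assumptions, not the repo's exact code:

```python
import torch.nn as nn

def build_decoder_body(channels, num_res_blocks, ResidualBlock, UpSampleBlock):
    """channels: channel count per resolution level, e.g. [512, 256, 256, 128, 128]
    (the block classes are passed in so the sketch stays independent of helper.py)."""
    layers = []
    in_channels = channels[0]
    num_resolutions = len(channels)
    for i in range(num_resolutions):
        out_channels = channels[i]
        for _ in range(num_res_blocks):
            layers.append(ResidualBlock(in_channels, out_channels))
            in_channels = out_channels

        # placement in this repo (decoder.py, line 26): skip the FIRST level,
        # so the final upsampling lands right before the output convolution
        # if i != 0:
        #     layers.append(UpSampleBlock(in_channels))

        # placement in taming-transformers: skip the LAST level instead, so
        # several residual blocks still run at full resolution after the final
        # upsampling and can smooth out upsampling artifacts
        if i != num_resolutions - 1:
            layers.append(UpSampleBlock(in_channels))

    return nn.Sequential(*layers)
```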

I hope this helps. I will update the comment with my results later.

SnakeOnex commented 1 month ago

Thanks, will look into it.

In the meantime, I have tried replacing the LPIPS and VQGAN modules with the ones from the original paper's repository, and removed the GAN components to simplify things. I got these results after 170 epochs.

170_400

aa1234241 commented 1 month ago

283_1000 Hello, I've uploaded a bug-fixed version. I hope it can help.

hgupta01 commented 1 month ago

Thanks, I also found the same errors as you did. After fixing the code, the results look good. I will add the normalization in the codebook as you suggested and try it. For faster training, I think you also need to change the learning rate to learning_rate = accumulate_grad_batches * ngpu * bs * base_lr, following the original implementation.
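For reference, a small sketch of that scaling rule; the concrete numbers are placeholders, and base_lr corresponds to the base learning rate entry in the taming-transformers config (treat the exact values as assumptions):

```python
# effective learning rate scaling as in the original taming-transformers setup:
# the base rate is multiplied by the total effective batch size
accumulate_grad_batches = 1   # gradient accumulation steps (placeholder)
ngpu = 1                      # number of GPUs used for training (placeholder)
bs = 12                       # per-GPU batch size (placeholder)
base_lr = 4.5e-6              # base learning rate from the config (assumed value)

learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(f"effective learning rate: {learning_rate:.2e}")
```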

I also plan to replace the codebook with https://github.com/lucidrains/vector-quantize-pytorch, which looks more stable.