LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701

mismatched size from pretrained vqgan #32

Closed · creatorcao closed this issue 2 months ago

creatorcao commented 3 months ago

Hi! Thank you for your great work! I tried to custom-train a VQGAN and load its checkpoint into the MAGE pixel generator, but I received the error below. Do you know why? I trained the VQGAN on a single GPU and didn't change the config file.

Traceback (most recent call last):
  File "main_mage.py", line 297, in <module>
    main(args)
  File "main_mage.py", line 197, in main
    model = models_mage.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
  File "./rcg/pixel_generator/mage/models_mage.py", line 594, in mage_vit_base_patch16
    model = MaskedGenerativeEncoderViT(
  File "./rcg/pixel_generator/mage/models_mage.py", line 299, in __init__
    self.vqgan = VQModel(ddconfig=vqgan_config.params.ddconfig,
  File "./rcg/pixel_generator/mage/taming/models/vqgan.py", line 28, in __init__
    self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  File "./rcg/pixel_generator/mage/taming/models/vqgan.py", line 50, in init_from_ckpt
    self.load_state_dict(sd, strict=False)
  File ".local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VQModel:
    size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
    size mismatch for encoder.down.4.block.0.nin_shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 1, 1]).
    size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
    size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
    size mismatch for decoder.up.3.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
LTH14 commented 3 months ago

MAGE uses a slightly different VQGAN network architecture than the original VQGAN. You could consider using the original network architecture instead: https://github.com/CompVis/taming-transformers/blob/master/taming/models/vqgan.py
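
For reference, a minimal sketch of loading a custom-trained checkpoint with the original taming-transformers VQModel; the config and checkpoint paths below are hypothetical placeholders, not files from this repo:

    from omegaconf import OmegaConf
    from taming.models.vqgan import VQModel  # original architecture

    # Hypothetical paths -- point these at your own training outputs.
    cfg = OmegaConf.load("configs/vqgan_custom.yaml").model.params
    vqgan = VQModel(
        ddconfig=cfg.ddconfig,
        lossconfig=cfg.lossconfig,  # the original VQModel requires a lossconfig
        n_embed=cfg.n_embed,
        embed_dim=cfg.embed_dim,
        ckpt_path="vqgan_custom.ckpt",
    )
    vqgan.eval()  # used as a frozen tokenizer; no training loss is needed here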

creatorcao commented 3 months ago

Thank you for your quick reply! I used the original VQGAN and it worked! However, it also required changing the loss config (lossconfig: target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator). The logged loss starts around 7 on a toy dataset. Do you know whether this affects MAGE training? Could you share your MAGE training log?

[12:34:50.471893] Epoch: [0]  [ 0/68]  eta: 0:43:28  lr: 0.000000  loss: 7.0747 (7.0747)  time: 38.3553  data: 2.9274  max mem: 6478
[12:35:10.637951] Epoch: [0]  [20/68]  eta: 0:02:13  lr: 0.000001  loss: 6.7850 (6.7941)  time: 1.0063  data: 0.0093  max mem: 7586
[12:35:28.299013] Epoch: [0]  [40/68]  eta: 0:00:51  lr: 0.000002  loss: 5.6179 (6.2519)  time: 0.8827  data: 0.0030  max mem: 7586
[12:35:47.316502] Epoch: [0]  [60/68]  eta: 0:00:12  lr: 0.000003  loss: 4.9155 (5.8232)  time: 0.9507  data: 0.0027  max mem: 7586
[12:35:54.612478] Epoch: [0]  [67/68]  eta: 0:00:01  lr: 0.000004  loss: 4.7324 (5.7004)  time: 0.9852  data: 0.0033  max mem: 7586
LTH14 commented 3 months ago

It won't affect MAGE training. The MAGE training loss is unrelated to this VQGAN lossconfig, which only specifies the loss used when training the VQGAN itself.

MAGE's training loss will be around 5.7 on ImageNet. However, the training loss can vary a lot depending on the dataset -- some datasets are easier while others are harder. Your training loss looks reasonable. I typically judge whether training works by looking at the generation performance rather than the training loss.

creatorcao commented 3 months ago

Great. Thanks a lot. 👍

creatorcao commented 2 months ago

Hi, I ran into the following error when evaluating MAGE. I used the vqgan.py from the taming repo to tokenize my own data and added that lossconfig to VQModel, but this error reports mismatched sizes when loading the checkpoint. Could you clarify? Is this the load_state_dict error someone mentioned earlier after training a VQGAN on a single GPU (because the saved weights lack the 'module.' prefix), or is it because MAGE's VQGAN architecture differs from the original vqgan.py?

Traceback (most recent call last):
  File "/gpfs/space/home/etais/hpc_ping/rcg/main_mage.py", line 298, in <module>
    main(args)
  File "/gpfs/space/home/etais/hpc_ping/rcg/main_mage.py", line 198, in main
    model = models_mage.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/models_mage.py", line 595, in mage_vit_base_patch16
    model = MaskedGenerativeEncoderViT(
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/models_mage.py", line 299, in __init__
    self.vqgan = VQModel(ddconfig=vqgan_config.params.ddconfig,
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/taming/models/vqgan.py", line 50, in __init__
    self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/taming/models/vqgan.py", line 66, in init_from_ckpt
    self.load_state_dict(sd, strict=False)
  File "/gpfs/space/home/etais/hpc_ping/.conda/envs/mages/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VQModel:
    size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
    size mismatch for encoder.down.4.block.0.nin_shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 1, 1]).
    size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
    size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
    size mismatch for decoder.up.3.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).

LTH14 commented 2 months ago

This is because MAGE's VQGAN architecture differs from the original vqgan.py.
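
A quick way to confirm the mismatch is architectural (rather than a missing 'module.' prefix) is to diff the parameter shapes directly. This helper is only a sketch; the checkpoint path and model handle in the usage comment are placeholders:

    import torch

    def report_shape_mismatches(ckpt_path, model):
        # Print every same-named parameter whose shape differs between
        # the saved checkpoint and the model it is being loaded into.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        model_sd = model.state_dict()
        for k in sorted(sd.keys() & model_sd.keys()):
            if sd[k].shape != model_sd[k].shape:
                print(f"{k}: checkpoint {tuple(sd[k].shape)} "
                      f"vs model {tuple(model_sd[k].shape)}")

    # e.g. report_shape_mismatches("vqgan_custom.ckpt", mage_model.vqgan)

Note that a missing 'module.' prefix from single-GPU vs. DataParallel training would show up as missing/unexpected keys, not size mismatches, so the error above points to differing layer shapes.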

creatorcao commented 2 months ago

Thanks for the explanation! The original vqgan.py contains self.loss = instantiate_from_config(lossconfig), so I added lossconfig=vqgan_config.params.lossconfig where the pretrained VQGAN is loaded in pixel_generator/mage/models_mage.py. It seemed that the only difference between MAGE's vqgan and the original vqgan.py was this loss, but that change produced the error above. Could you show me how to fix it?

LTH14 commented 2 months ago

As I recall, MAGE's VQGAN changes are mainly to the encoder/decoder network structure (e.g., it has no attention). Since the VQGAN loss is not needed for MAGE training, I suggest removing it on both sides (and dropping it from the checkpoint after the VQGAN finishes training).
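
One way to do the checkpoint-side removal, as a sketch (the file names are hypothetical). Alternatively, since init_from_ckpt already accepts an ignore_keys argument, passing ignore_keys=["loss"] should drop the same weights at load time:

    import torch

    # Strip the VQGAN training-loss weights (discriminator, LPIPS, etc.)
    # from a trained checkpoint; MAGE never uses them.
    ckpt = torch.load("vqgan_custom.ckpt", map_location="cpu")
    ckpt["state_dict"] = {k: v for k, v in ckpt["state_dict"].items()
                          if not k.startswith("loss.")}
    torch.save(ckpt, "vqgan_custom_noloss.ckpt")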

creatorcao commented 2 months ago

I printed out both VQGANs (the pretrained VQGAN checkpoint and MAGE's VQGAN) and removed the extra structures in the pretrained checkpoint, such as the loss and the differing encoder/decoder layers, but I still get the same error as above. The configs on both sides are also identical. What else could be the cause? And would removing parts of the weights like this affect the generation results?

LTH14 commented 2 months ago

This error occurs because layers with the same names have different structures in MAGE's VQGAN and the original VQGAN. Since you have your own trained VQGAN checkpoint, I suggest replacing the VQGAN file in MAGE with the original VQGAN directly. Loading an original VQGAN checkpoint with MAGE's VQGAN file will not work.
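
Concretely, that replacement amounts to making MAGE import the original VQModel instead of its local copy. A sketch only: the exact import line in models_mage.py may differ, and the original class additionally needs a lossconfig when instantiated, per the discussion above:

    # In pixel_generator/mage/models_mage.py (sketch):
    # replace the import of the MAGE-local copy
    #   from pixel_generator.mage.taming.models.vqgan import VQModel
    # with the original taming-transformers class
    from taming.models.vqgan import VQModel

    # ...and pass the loss config where it is instantiated:
    # self.vqgan = VQModel(ddconfig=vqgan_config.params.ddconfig,
    #                      lossconfig=vqgan_config.params.lossconfig,
    #                      n_embed=..., embed_dim=..., ckpt_path=...)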