LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

Model mismatch for MAGE model.load_state_dict(checkpoint['model']) when running viz_rcg.ipynb #13

Open · Yisher opened this issue 8 months ago

Yisher commented 8 months ago

Hello! Thank you for your great work. I trained rdm.pth with main_rdm.py and mage.pth with main_mage.py. When I try to visualize the generation, I run into this problem:

RuntimeError                              Traceback (most recent call last)
Cell In[9], line 2
      1 checkpoint = torch.load(os.path.join('output/checkpoint-last.pth'), map_location='cpu')
----> 2 model.load_state_dict(checkpoint['model'], strict=True)
      3 model.cuda()
      4 _ = model.eval()

RuntimeError: Error(s) in loading state_dict for MaskedGenerativeEncoderViT:
    size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 1024]).
    size mismatch for pos_embed: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
    size mismatch for mask_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 1024]).
    size mismatch for decoder_pos_embed: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
    size mismatch for decoder_pos_embed_learned: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
    size mismatch for token_emb.word_embeddings.weight: copying a param with shape torch.Size([2025, 768]) from checkpoint, the shape in current model is torch.Size([2025, 1024]).
    size mismatch for token_emb.position_embeddings.weight: copying a param with shape torch.Size([257, 768]) from checkpoint, the shape in current model is torch.Size([257, 1024]).
    ...
    size mismatch for decoder_pred.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([768, 1024]).
    size mismatch for mlm_layer.fc.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for mlm_layer.fc.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for mlm_layer.ln.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for mlm_layer.ln.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).

I can't understand why I get this error. I used my own dataset for training and did not use distributed training.
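
When load_state_dict(strict=True) fails like this, comparing parameter shapes before loading shows at a glance whether the checkpoint and the freshly built model differ only in width (768 vs 1024 here) or also in which keys exist. The sketch below is plain Python and works on name-to-shape dictionaries. compare_shapes is a hypothetical helper, not part of the rcg repository, and it assumes the shapes were collected beforehand, for example with {k: tuple(v.shape) for k, v in checkpoint['model'].items()} and the same expression over model.state_dict().

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_shapes(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> Dict[str, List]:
    """Hypothetical helper: list the three problems strict loading reports."""
    # Keys the model expects but the checkpoint does not provide.
    missing = [k for k in model if k not in ckpt]
    # Keys stored in the checkpoint that the model has no slot for.
    unexpected = [k for k in ckpt if k not in model]
    # Keys present on both sides whose shapes disagree: (name, ckpt shape, model shape).
    mismatched = [(k, ckpt[k], model[k]) for k in model if k in ckpt and ckpt[k] != model[k]]
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Two of the shapes reported in the traceback above.
ckpt_shapes = {"cls_token": (1, 1, 768), "pos_embed": (1, 257, 768)}
model_shapes = {"cls_token": (1, 1, 1024), "pos_embed": (1, 257, 1024)}

report = compare_shapes(ckpt_shapes, model_shapes)
print(report["mismatched"][0])  # ('cls_token', (1, 1, 768), (1, 1, 1024))
```

If every mismatch is 768 vs 1024 and nothing is missing or unexpected, the checkpoint and the model most likely come from different model sizes rather than from different configurations.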

Yisher commented 8 months ago

And I trained in base mode, not the large or huge one. Here are the details:

Model = MaskedGenerativeEncoderViT(
  (token_emb): BertEmbeddings(
    (word_embeddings): Embedding(2025, 768)
    (position_embeddings): Embedding(257, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (decoder_embed): Linear(in_features=768, out_features=768, bias=True)
  (decoder_blocks): ModuleList(
    (0-7): 8 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (decoder_pred): Linear(in_features=768, out_features=768, bias=True)
  (mlm_layer): MlmLayer(
    (fc): Linear(in_features=768, out_features=768, bias=True)
    (gelu): GELU(approximate='none')
    (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (criterion): LabelSmoothingCrossEntropy()
  (pretrained_encoder): VisionTransformerMoCo(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (head): Sequential(
      (0): Linear(in_features=768, out_features=4096, bias=False)
      (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Linear(in_features=4096, out_features=4096, bias=False)
      (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Linear(in_features=4096, out_features=256, bias=False)
      (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )
  )
  (vqgan): VQModel(
    (encoder): Encoder(
      (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (down): ModuleList(
        (0-1): 2 x Module(
          (block): ModuleList(
            (0-1): 2 x ResnetBlock(
              (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
              (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (downsample): Downsample()
        )
        (2): Module(
          (block): ModuleList(
            (0): ResnetBlock(
              (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
              (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (nin_shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (1): ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (downsample): Downsample()
        )
        (3): Module(
          (block): ModuleList(
            (0-1): 2 x ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (downsample): Downsample()
        )
        (4): Module(
          (block): ModuleList(
            (0): ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (nin_shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (1): ResnetBlock(
              (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
              (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
        )
      )
      (mid): Module(
        (block_1): ResnetBlock(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (block_2): ResnetBlock(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
      (conv_out): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    )
    (decoder): Decoder(
      (conv_in): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (mid): Module(
        (block_1): ResnetBlock(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (block_2): ResnetBlock(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (up): ModuleList(
        (0): Module(
          (block): ModuleList(
            (0-1): 2 x ResnetBlock(
              (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
              (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
        )
        (1): Module(
          (block): ModuleList(
            (0): ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (nin_shortcut): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (1): ResnetBlock(
              (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
              (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (upsample): Upsample(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (2): Module(
          (block): ModuleList(
            (0-1): 2 x ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (upsample): Upsample(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (3): Module(
          (block): ModuleList(
            (0): ResnetBlock(
              (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
              (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (nin_shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (1): ResnetBlock(
              (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (upsample): Upsample(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (4): Module(
          (block): ModuleList(
            (0-1): 2 x ResnetBlock(
              (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
              (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (upsample): Upsample(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
      )
      (norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
      (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (quantize): VectorQuantizer2(
      (embedding): Embedding(1024, 256)
    )
  )
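
Because the checkpoint itself records the embedding width, one way to avoid building the wrong model size in viz_rcg.ipynb is to read that width from cls_token before choosing a constructor. The helper below is hypothetical, and its table only covers the two sizes that appear in this thread: 768 for the base model printed above and 1024 for mage_vit_large_patch16. Any other width is rejected rather than guessed.

```python
from typing import Tuple

# Widths observed in this thread only; other MAGE sizes are deliberately not listed.
WIDTH_TO_CONSTRUCTOR = {
    768: "mage_vit_base_patch16",
    1024: "mage_vit_large_patch16",
}


def guess_mage_constructor(cls_token_shape: Tuple[int, ...]) -> str:
    """Hypothetical helper: map the checkpoint's cls_token width to a constructor name."""
    width = cls_token_shape[-1]
    if width not in WIDTH_TO_CONSTRUCTOR:
        raise ValueError(f"unrecognized embedding width {width}")
    return WIDTH_TO_CONSTRUCTOR[width]


print(guess_mage_constructor((1, 1, 768)))  # mage_vit_base_patch16
```

In the notebook, the shape would come from tuple(checkpoint['model']['cls_token'].shape). The returned name could then be looked up with getattr(models_mage, name) and called with the notebook's usual arguments.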

Yisher commented 8 months ago

I changed the notebook's code from "model = models_mage.mage_vit_large_patch16" to "model = models_mage.mage_vit_base_patch16", because I trained the model in base mode. Now the error has changed to a new one:

RuntimeError                              Traceback (most recent call last)
Cell In[13], line 2
      1 checkpoint = torch.load(os.path.join('output/checkpoint-last.pth'), map_location='cpu')
----> 2 model.load_state_dict(checkpoint['model'], strict=True)
      3 model.cuda()
      4 _ = model.eval()

RuntimeError: Error(s) in loading state_dict for MaskedGenerativeEncoderViT:
    Missing key(s) in state_dict: "latent_prior_proj.weight", "latent_prior_proj.bias".

It seems only two keys are missing, but I don't know why.

LTH14 commented 8 months ago

Did you set use_rep when training MAGE? It seems there's no latent_prior_proj in your trained model, which should be initialized here https://github.com/LTH14/rcg/blob/main/pixel_generator/mage/models_mage.py#L196-L198
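
The missing-key error can be reproduced with a toy module whose projection layer only exists when a flag is set. ToyGenerator below is a hypothetical stand-in, not the repository's MaskedGenerativeEncoderViT. It assumes, as the reply above suggests, that latent_prior_proj is created only when use_rep is enabled, and it models that layer as an nn.Linear(rep_dim, embed_dim).

```python
import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in for the MAGE model; only the parts relevant here."""

    def __init__(self, embed_dim: int = 768, use_rep: bool = False, rep_dim: int = 256):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_rep:
            # Assumption: the projection is only built when use_rep is set.
            self.latent_prior_proj = nn.Linear(rep_dim, embed_dim)


# A checkpoint saved from a run trained without use_rep ...
trained_state = ToyGenerator(use_rep=False).state_dict()
# ... loaded into a model built with use_rep=True, as in the notebook.
model = ToyGenerator(use_rep=True)

result = model.load_state_dict(trained_state, strict=False)
print(result.missing_keys)  # ['latent_prior_proj.weight', 'latent_prior_proj.bias']
```

With strict=True the same call raises the "Missing key(s) in state_dict" error shown above. Loading with strict=False only hides the problem, because latent_prior_proj would stay randomly initialized. The sound fix is to retrain with use_rep enabled.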

Yisher commented 8 months ago

I'll try it later when my GPU is available, thanks for the reply! Also, because I train my model on Windows with only a single 4070 GPU, I can't initialize with torch.distributed.launch, so I keep running into a problem with this function in misc.py:

def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    Warning: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

It fails with "RuntimeError: Default process group has not been initialized, please make sure to call init_process_group." when the code reaches torch.distributed.get_world_size() and torch.distributed.all_gather. So I changed the code to:

    tensors_gather = [torch.ones_like(tensor) for _ in range(1)]
    output = torch.cat(tensors_gather, dim=0)
    return output

With that, the code runs. Do you think my change is acceptable? I'm afraid it will break the structure of the layers.

LTH14 commented 8 months ago

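No -- you should simply comment out the concat_all_gather line. Your modification will return an output full of ones: tensors_gather = [torch.ones_like(tensor) for _ in range(1)] only creates placeholder tensors of ones, and since the all_gather call that would fill them has been removed, they are returned unchanged.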
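
As an alternative to commenting out the call, concat_all_gather could fall back to returning its input when no process group exists. With a single process, gathering and concatenating would return the tensor unchanged anyway. The variant below is a sketch, not the repository's code. It uses only torch.distributed.is_available, is_initialized, get_world_size and all_gather. The last lines show why the modification above produces ones.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather `tensor` from all processes and concatenate along dim 0.

    Sketch of a single-process-safe variant: if no process group has been
    initialized (e.g. a plain single-GPU run), the world size is effectively 1,
    so the input is returned as-is.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    # Placeholders that all_gather overwrites with each rank's tensor.
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)


x = torch.tensor([[2.0, 3.0]])
print(concat_all_gather(x))  # tensor([[2., 3.]])

# The modification from the thread: the placeholders are never overwritten,
# because all_gather is no longer called, so the result is all ones.
broken = torch.cat([torch.ones_like(x) for _ in range(1)], dim=0)
print(broken)  # tensor([[1., 1.]])
```

The multi-GPU path is unchanged. Only the non-distributed case short-circuits.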

Yisher commented 8 months ago

OK, I think that will be fine, but how should I comment it out? I ask because of the code at https://github.com/LTH14/rcg/blob/4f1c32fe1378f9d7ec39727558ddf8ce2e9a8c9a/engine_mage.py#L107. When I comment it out, I get a new error:

  File "D:\DeepLearning\rcg-main\main_mage.py", line 297, in <module>
    main(args)
  File "D:\DeepLearning\rcg-main\main_mage.py", line 270, in main
    gen_img(model, args, epoch, batch_size=16, log_writer=log_writer, cfg=0)
  File "D:\DeepLearning\rcg-main\engine_mage.py", line 102, in gen_img
    gen_images_batch, _ = model(None, None,
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\DeepLearning\rcg-main\pixel_generator\mage\models_mage.py", line 455, in forward
    return self.gen_image(bsz, num_iter, choice_temperature, sampled_rep, rdm_steps, eta, cfg, class_label_gen)
  File "D:\DeepLearning\rcg-main\pixel_generator\mage\models_mage.py", line 533, in gen_image
    input_embeddings[:, 0] = self.latent_prior_proj(sampled_rep)
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x256 and 768x768)

(In this run I had set use_rep.)
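
The shape error can be reproduced in isolation. This sketch assumes that latent_prior_proj is an nn.Linear whose input width is rep_dim. The 768x768 weight in the message suggests rep_dim was 768, while the sampled representation in the batch of 16 is 256-dimensional. The sizes come from the error message. Everything else is illustrative.

```python
import torch
import torch.nn as nn

batch, rep_width, embed_dim = 16, 256, 768  # sizes taken from the error message

# Stand-in for the 256-d representation sampled by the RDM.
sampled_rep = torch.randn(batch, rep_width)

wrong_proj = nn.Linear(768, embed_dim)        # input width 768: mismatched
right_proj = nn.Linear(rep_width, embed_dim)  # input width equals the representation (256)

try:
    wrong_proj(sampled_rep)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
    print("shape mismatch:", err)

out = right_proj(sampled_rep)
print(out.shape)  # torch.Size([16, 768])
```

The projection's input width has to match the representation's width. That is why the fix is a configuration change rather than a code change.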

LTH14 commented 8 months ago

Just comment it out and it should be fine. This error is caused by use_rep, not by commenting the line out: you need to set --rep_dim=256. Please follow the command provided in the README, with its arguments:


python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --node_rank=0 \
main_mage.py \
--pretrained_enc_arch mocov3_vit_base \
--pretrained_enc_path pretrained_enc_ckpts/mocov3/vitb.pth.tar --rep_drop_prob 0.1 \
--use_rep --rep_dim 256 --pretrained_enc_withproj --pretrained_enc_proj_dim 256 \
--pretrained_rdm_cfg ${RDM_CFG_PATH} --pretrained_rdm_ckpt ${RDM_CKPT_PATH} \
--rdm_steps 250 --eta 1.0 --temp 6.0 --num_iter 20 --num_images 50000 --cfg 0.0 \
--batch_size 64 --input_size 256 \
--model mage_vit_base_patch16 \
--mask_ratio_min 0.5 --mask_ratio_max 1.0 --mask_ratio_mu 0.75 --mask_ratio_std 0.25 \
--epochs 200 \
--warmup_epochs 10 \
--blr 1.5e-4 --weight_decay 0.05 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214
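
A small pre-flight check can catch the mismatch above before training starts. It assumes that with --pretrained_enc_withproj the representation comes out of the encoder's projection head. That head ends in 256 features in the printout earlier in this thread, and the command sets both --pretrained_enc_proj_dim and --rep_dim to 256. Under that assumption, --rep_dim should equal --pretrained_enc_proj_dim. check_rep_args is hypothetical and not part of main_mage.py.

```python
from argparse import Namespace


def check_rep_args(args: Namespace) -> None:
    """Hypothetical sanity check for the representation-related flags.

    Assumption: with --pretrained_enc_withproj the representation has
    --pretrained_enc_proj_dim features, so --rep_dim must match it.
    """
    if args.use_rep and args.pretrained_enc_withproj and args.rep_dim != args.pretrained_enc_proj_dim:
        raise ValueError(
            f"--rep_dim {args.rep_dim} does not match "
            f"--pretrained_enc_proj_dim {args.pretrained_enc_proj_dim}"
        )


# Values from the README command above: passes silently.
check_rep_args(Namespace(use_rep=True, rep_dim=256,
                         pretrained_enc_withproj=True, pretrained_enc_proj_dim=256))
```

Calling it right after argument parsing would turn the late matmul failure during image generation into an immediate, readable error.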

Yisher commented 8 months ago

Sorry for asking so many questions. When I follow these arguments, it works, and the code is running now. The --pretrained_enc_withproj flag is also important. When I get the new output tomorrow, I'll post here whether the results look good. Thank you for your reply!

LTH14 commented 8 months ago

No worries -- please let me know if you encounter other problems.