voletiv / mcvd-pytorch

Official implementation of MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation (https://arxiv.org/abs/2205.09853)
MIT License

Finetuning model from pretrained checkpoint leads to size mismatch #18

Closed: eiriksteen closed this issue 1 year ago

eiriksteen commented 1 year ago

There seems to be a discrepancy between the BAIR config files and the parameter shapes in the pretrained checkpoints from the provided links. I have tried all of the BAIR configs with the checkpoint from bair64_big192_5c1_pmask50_unetm-20230211T213918Z-002. I ran this training command: "CUDA_VISIBLE_DEVICES=0 python main.py --config configs/bair.yml --data_path datasets/echo_h5 --resume_training --exp bair --ni". This is the terminal output:

size mismatch for module.unet.all_modules.39.Conv_1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for module.unet.all_modules.39.Conv_2.weight: copying a param with shape torch.Size([192, 384, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1, 1]).
size mismatch for module.unet.all_modules.39.Conv_2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([64]).
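To see at a glance which settings differ, one can compare the parameter shapes stored in the checkpoint against those of the freshly built model before loading. The sketch below works on plain {name: shape} dictionaries so it runs without PyTorch; with real state dicts they could be built as {k: tuple(v.shape) for k, v in sd.items()}. The helper name find_shape_mismatches is hypothetical, and the shapes are copied from the error above.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    ckpt_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Hypothetical helper: return (name, checkpoint_shape, model_shape) for
    every parameter present in both dicts whose shapes disagree, sorted by name."""
    return [
        (name, ckpt_shapes[name], model_shapes[name])
        for name in sorted(ckpt_shapes.keys() & model_shapes.keys())
        if ckpt_shapes[name] != model_shapes[name]
    ]


# Shapes copied from the error message in this issue.
ckpt = {
    "module.unet.all_modules.39.Conv_1.bias": (192,),
    "module.unet.all_modules.39.Conv_2.weight": (192, 384, 1, 1),
    "module.unet.all_modules.39.Conv_2.bias": (192,),
}
model = {
    "module.unet.all_modules.39.Conv_1.bias": (64,),
    "module.unet.all_modules.39.Conv_2.weight": (64, 128, 1, 1),
    "module.unet.all_modules.39.Conv_2.bias": (64,),
}

for name, c, m in find_shape_mismatches(ckpt, model):
    # Channel dimensions differ by a constant factor of 3 (192/64, 384/128)
    # while the 1x1 kernel dims match, which suggests a width setting differs.
    ratios = {a / b for a, b in zip(c, m) if b}
    print(f"{name}: checkpoint {c} vs model {m}, ratios {sorted(ratios)}")
```

A uniform factor across every channel dimension is consistent with the checkpoint having been trained with a wider network, such as the model.ngf=192 override mentioned in the reply, rather than with a different architecture.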

AlexiaJM commented 1 year ago

Hi @eiriksteen,

The config files don't always contain the exact hyperparameters; it's best to look at https://github.com/voletiv/mcvd-pytorch/blob/master/example_scripts/final/training_scripts.sh.

Any parameter passed on the command line overrides the default from the config file. For example, https://github.com/voletiv/mcvd-pytorch/blob/451da2eb635bad50da6a7c03b443a34c6eb08b3a/example_scripts/final/training_scripts.sh#L128 uses the BAIR config but overrides these values: training.snapshot_freq=50000 model.ngf=192 model.n_head_channels=192 sampling.num_frames_pred=28 data.num_frames=5 data.num_frames_cond=1 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore.
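The override behaviour described above can be illustrated with a small helper that applies dotted key=value strings to a nested config dictionary. This is a hypothetical sketch, not the repository's code: the function names, the nested-dict representation, the default values, and the value-parsing rules (booleans, then integers, then floats, otherwise strings) are all assumptions made for illustration.

```python
import copy
from typing import Any, Dict, Iterable


def parse_value(text: str) -> Any:
    """Assumed parsing rules: True/False -> bool, then int, then float, else str."""
    if text in ("True", "False"):
        return text == "True"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Return a copy of `config` with each 'section.key=value' override applied,
    so command-line values replace the defaults loaded from the YAML file."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # Create missing sections so new keys can also be introduced.
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return result


# Hypothetical defaults standing in for a YAML config (values are illustrative).
defaults = {"model": {"ngf": 64}, "training": {"batch_size": 32}}
cfg = apply_overrides(defaults, ["model.ngf=192", "model.version=DDPM",
                                 "sampling.clip_before=True", "training.batch_size=64"])
print(cfg["model"]["ngf"], cfg["model"]["version"], cfg["sampling"]["clip_before"])
# → 192 DDPM True
```

Deep-copying keeps the loaded defaults untouched, so the same base config can be reused with different override sets; this is also why a run that omits model.ngf=192 builds the narrower model and cannot load the wider checkpoint.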

eiriksteen commented 1 year ago

Thank you for the help and quick response!