jychoi118 / ilvr_adm

ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral)

Error in ilvr_sample after loading my trained model #9

Closed: vinodrajendran001 closed this issue 2 years ago

vinodrajendran001 commented 2 years ago

Hi,

I trained a model from scratch on my own dataset. After training, I ended up with checkpoint files such as ema_0.9999_000010.pt, model000010.pt, and opt000010.pt.

Flags used for training:

python train_model.py --data_dir /data1/ --image_size 256 --num_channels 128 --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule cosine --lr 1e-4 --batch_size 4 --save_dir /data2/

I used the checkpoint file ema_0.9999_000010.pt for ILVR sampling, but it threw the following error.

Flags used for sampling:

python src/models/ILVR_GuidedDiffusion/ilvr_sample.py  --attention_resolutions 16 --class_cond False --diffusion_steps 4000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule cosine --num_channels 128 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 100 --model_path /data2/ema_0.9999_000010.pt --base_samples ref_imgs/bdd10k --down_N 32 --range_t 20 --save_dir reports/figures/guided

Error:

Logging to reports/figures/guided
creating model...
Traceback (most recent call last):
  File "src/models/ILVR_GuidedDiffusion/ilvr_sample.py", line 134, in <module>
    main()
  File "src/models/ILVR_GuidedDiffusion/ilvr_sample.py", line 49, in main
    model.load_state_dict(
  File "/home/vinod/anaconda3/envs/lsgm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
        Missing key(s) in state_dict: "input_blocks.2.0.in_layers.0.weight", "input_blocks.2.0.in_layers.0.bias", "input_blocks.2.0.in_layers.2.weight", "input_blocks.2.0.in_layers.2.bias", "input_blocks.2.0.emb_layers.1.weight", "input_blocks.2.0.emb_layers.1.bias", "input_blocks.2.0.out_layers.0.weight", "input_blocks.2.0.out_layers.0.bias", "input_blocks.2.0.out_layers.3.weight", "input_blocks.2.0.out_layers.3.bias", "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias", "output_blocks.1.1.in_layers.0.weight", "output_blocks.1.1.in_layers.0.bias", "output_blocks.1.1.in_layers.2.weight", "output_blocks.1.1.in_layers.2.bias", "output_blocks.1.1.emb_layers.1.weight", "output_blocks.1.1.emb_layers.1.bias", "output_blocks.1.1.out_layers.0.weight", "output_blocks.1.1.out_layers.0.bias", "output_blocks.1.1.out_layers.3.weight", "output_blocks.1.1.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.5.1.in_layers.0.weight", "output_blocks.5.1.in_layers.0.bias", "output_blocks.5.1.in_layers.2.weight", "output_blocks.5.1.in_layers.2.bias", "output_blocks.5.1.emb_layers.1.weight", "output_blocks.5.1.emb_layers.1.bias", "output_blocks.5.1.out_layers.0.weight", "output_blocks.5.1.out_layers.0.bias", "output_blocks.5.1.out_layers.3.weight", "output_blocks.5.1.out_layers.3.bias", "output_blocks.7.1.in_layers.0.weight", "output_blocks.7.1.in_layers.0.bias", "output_blocks.7.1.in_layers.2.weight", "output_blocks.7.1.in_layers.2.bias", "output_blocks.7.1.emb_layers.1.weight", "output_blocks.7.1.emb_layers.1.bias", "output_blocks.7.1.out_layers.0.weight", "output_blocks.7.1.out_layers.0.bias", "output_blocks.7.1.out_layers.3.weight", 
"output_blocks.7.1.out_layers.3.bias", "output_blocks.9.1.in_layers.0.weight", "output_blocks.9.1.in_layers.0.bias", "output_blocks.9.1.in_layers.2.weight", "output_blocks.9.1.in_layers.2.bias", "output_blocks.9.1.emb_layers.1.weight", "output_blocks.9.1.emb_layers.1.bias", "output_blocks.9.1.out_layers.0.weight", "output_blocks.9.1.out_layers.0.bias", "output_blocks.9.1.out_layers.3.weight", "output_blocks.9.1.out_layers.3.bias". 
        Unexpected key(s) in state_dict: "input_blocks.2.0.op.weight", "input_blocks.2.0.op.bias", "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.10.0.op.weight", "input_blocks.10.0.op.bias", "input_blocks.11.1.norm.weight", "input_blocks.11.1.norm.bias", "input_blocks.11.1.qkv.weight", "input_blocks.11.1.qkv.bias", "input_blocks.11.1.proj_out.weight", "input_blocks.11.1.proj_out.bias", "output_blocks.0.1.norm.weight", "output_blocks.0.1.norm.bias", "output_blocks.0.1.qkv.weight", "output_blocks.0.1.qkv.bias", "output_blocks.0.1.proj_out.weight", "output_blocks.0.1.proj_out.bias", "output_blocks.1.2.conv.weight", "output_blocks.1.2.conv.bias", "output_blocks.1.1.norm.weight", "output_blocks.1.1.norm.bias", "output_blocks.1.1.qkv.weight", "output_blocks.1.1.qkv.bias", "output_blocks.1.1.proj_out.weight", "output_blocks.1.1.proj_out.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.5.1.conv.weight", "output_blocks.5.1.conv.bias", "output_blocks.7.1.conv.weight", "output_blocks.7.1.conv.bias", "output_blocks.9.1.conv.weight", "output_blocks.9.1.conv.bias". 
        size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
        size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).

May I know how you generated the 256x256 FFHQ.pt and 256x256 AFHQ-dog.pt checkpoints? I don't have any issues while loading those weights.
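
A quick way to diagnose mismatches like this is to build the model with the sampling flags and diff its state_dict keys against the checkpoint's. A minimal sketch, assuming the repo's guided_diffusion package is importable; the checkpoint path and flag values below are only illustrative:

    import torch
    from guided_diffusion.script_util import (
        model_and_diffusion_defaults,
        create_model_and_diffusion,
    )

    # Build the model with the flags you intend to pass to ilvr_sample.py
    # and compare its parameter names against the checkpoint's.
    opts = model_and_diffusion_defaults()
    opts.update(dict(
        image_size=256, num_channels=128, num_res_blocks=1,
        diffusion_steps=4000, noise_schedule="cosine", learn_sigma=True,
        attention_resolutions="16", resblock_updown=True,
        use_scale_shift_norm=True, timestep_respacing="100",
    ))
    model, _ = create_model_and_diffusion(**opts)

    ckpt = torch.load("/data2/ema_0.9999_000010.pt", map_location="cpu")
    missing = sorted(set(model.state_dict()) - set(ckpt))
    unexpected = sorted(set(ckpt) - set(model.state_dict()))
    print("missing from checkpoint:", missing[:5])
    print("unexpected in checkpoint:", unexpected[:5])

Any keys printed here point at a training/sampling flag disagreement.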

jychoi118 commented 2 years ago

I trained my models with --learn_sigma True. This enables learning the variance of the reverse process, so the model outputs 6 channels rather than 3. Since you trained without it, you may try --learn_sigma False when sampling with your own models.
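
Concretely, in guided-diffusion's script_util the output head is built with out_channels = (3 if not learn_sigma else 6), and at sampling time those channels are split into the predicted noise and the variance values. A rough self-contained sketch of the effect:

    import torch

    learn_sigma = True
    out_channels = 3 if not learn_sigma else 6  # RGB mean only vs. mean + variance

    # Stand-in for the UNet's final conv output at 256x256:
    model_output = torch.randn(1, out_channels, 256, 256)
    if learn_sigma:
        eps, model_var_values = torch.split(model_output, 3, dim=1)

    # Loading a 3-channel checkpoint into a 6-channel model (or vice versa)
    # fails exactly on out.2.weight / out.2.bias, as in the traceback above.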

vinodrajendran001 commented 2 years ago

Great, that works :)

vinodrajendran001 commented 2 years ago

Likewise, may I know the complete set of flags you used for training? Perhaps updating the README with them would be useful.

jychoi118 commented 2 years ago

I trained models with --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 --resblock_updown True --use_fp16 False --use_scale_shift_norm True
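
The same architecture flags must then be passed to ilvr_sample.py, or loading fails with key/shape errors like the one above. For illustration only, the flags above collected as a Python dict (the actual scripts take them as command-line arguments):

    # Architecture flags that must agree between training and sampling.
    ARCH_FLAGS = dict(
        attention_resolutions="16",
        class_cond=False,
        diffusion_steps=1000,
        dropout=0.0,
        image_size=256,
        learn_sigma=True,       # output has 6 channels; see above
        noise_schedule="linear",
        num_channels=128,
        num_res_blocks=1,
        num_head_channels=64,
        resblock_updown=True,
        use_fp16=False,
        use_scale_shift_norm=True,
    )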