SusungHong / Self-Attention-Guidance

The implementation of the paper "Improving Sample Quality of Diffusion Models Using Self-Attention Guidance" (ICCV'23)
MIT License

Error while loading checkpoints #10

Open arnobanu opened 8 months ago

arnobanu commented 8 months ago

Thank you very much for the contribution and for releasing the code. I wanted to run the sampling command for the LSUN dataset:

```
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.05 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python image_sample.py $SAG_FLAGS $MODEL_FLAGS --model_path models/lsun_cat.pt $SAMPLE_FLAGS
```

but it fails with missing and unexpected keys (the full exception trace is shown; execution is paused at `_run_module_as_main`):

```
RuntimeError: Error(s) in loading state_dict for UNetModel:
	Missing key(s) in state_dict: "label_emb.weight", "input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias", "input_blocks.8.1.norm.weight", "input_blocks.8.1.norm.bias", "input_blocks.8.1.qkv.weight", "input_blocks.8.1.qkv.bias", "input_blocks.8.1.proj_out.weight", "input_blocks.8.1.proj_out.bias", "input_blocks.10.0.skip_connection.weight", "input_blocks.10.0.skip_connection.bias".
	Unexpected key(s) in state_dict: "input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", ....
```

I downloaded the lsun_cat.pt model from the provided link and placed it in the models folder.
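One way to narrow down a mismatch like this is to diff the key sets directly: `load_state_dict` reports "missing" keys (the in-memory model expects them, the checkpoint lacks them) and "unexpected" keys (the reverse). For instance, `label_emb.weight` appearing under *missing* hints that the model was constructed class-conditionally even though the checkpoint is unconditional. Below is a minimal sketch of such a diff; the helper `diff_state_dict_keys` is hypothetical (not part of this repo), and the `torch.load` lines in the comments assume the paths from the command above.

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Mirror load_state_dict's report: (missing, unexpected) as sorted lists."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)      # expected by model, absent in checkpoint
    unexpected = sorted(ckpt_keys - model_keys)   # present in checkpoint, unknown to model
    return missing, unexpected

# With the real files, the key sets would come from (assumed paths/objects):
#   import torch
#   ckpt_keys = torch.load("models/lsun_cat.pt", map_location="cpu").keys()
#   model_keys = model.state_dict().keys()  # model built from MODEL_FLAGS
#
# Toy illustration using two keys from the error message:
missing, unexpected = diff_state_dict_keys(
    ["label_emb.weight", "input_blocks.7.1.norm.weight"],
    ["input_blocks.7.1.norm.weight", "input_blocks.15.0.in_layers.0.weight"],
)
print(missing)      # ['label_emb.weight']
print(unexpected)   # ['input_blocks.15.0.in_layers.0.weight']
```

Grouping the diff by block index (e.g. everything under `input_blocks.15.*`) usually makes it obvious whether the mismatch comes from `--class_cond`, `--attention_resolutions`, or a different channel/depth configuration than the one the checkpoint was trained with.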