ML-GSAI / EGSDE

Official implementation for "EGSDE: Unpaired Image-to-Image Translation via Energy-Guided Stochastic Differential Equations" (NeurIPS 2022)

Error in inference with the model trained on my custom dataset #3

Closed. xie-qiang closed this issue 1 year ago.

xie-qiang commented 1 year ago

Hello, I'm trying to run EGSDE on my custom dataset. I trained a diffusion model on my dataset with guided-diffusion, and also trained a DSE on it with run_train_dse.py. But when I run run_EGSDE.py for inference on my test data, the following error occurs, and I do not know how to solve it:

Traceback (most recent call last):
  File "run_EGSDE_xieqiang.py", line 41, in <module>
    run_egsde(task)
  File "run_EGSDE_xieqiang.py", line 37, in run_egsde
    runner.egsde()
  File "/code/EGSDE/runners/egsde.py", line 109, in egsde
    model.load_state_dict(states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:

Missing key(s) in state_dict: "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "input_blocks.16.0.in_layers.0.weight", "input_blocks.16.0.in_layers.0.bias", "input_blocks.16.0.in_layers.2.weight", "input_blocks.16.0.in_layers.2.bias", "input_blocks.16.0.emb_layers.1.weight", "input_blocks.16.0.emb_layers.1.bias", "input_blocks.16.0.out_layers.0.weight", "input_blocks.16.0.out_layers.0.bias", "input_blocks.16.0.out_layers.3.weight", "input_blocks.16.0.out_layers.3.bias", "input_blocks.20.0.in_layers.0.weight", "input_blocks.20.0.in_layers.0.bias", "input_blocks.20.0.in_layers.2.weight", "input_blocks.20.0.in_layers.2.bias", "input_blocks.20.0.emb_layers.1.weight", "input_blocks.20.0.emb_layers.1.bias", "input_blocks.20.0.out_layers.0.weight", "input_blocks.20.0.out_layers.0.bias", "input_blocks.20.0.out_layers.3.weight", "input_blocks.20.0.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.7.2.in_layers.0.weight", "output_blocks.7.2.in_layers.0.bias", "output_blocks.7.2.in_layers.2.weight", "output_blocks.7.2.in_layers.2.bias", "output_blocks.7.2.emb_layers.1.weight", "output_blocks.7.2.emb_layers.1.bias", "output_blocks.7.2.out_layers.0.weight", "output_blocks.7.2.out_layers.0.bias", "output_blocks.7.2.out_layers.3.weight", "output_blocks.7.2.out_layers.3.bias", "output_blocks.11.1.in_layers.0.weight", "output_blocks.11.1.in_layers.0.bias", "output_blocks.11.1.in_layers.2.weight", "output_blocks.11.1.in_layers.2.bias", "output_blocks.11.1.emb_layers.1.weight", "output_blocks.11.1.emb_layers.1.bias", "output_blocks.11.1.out_layers.0.weight", "output_blocks.11.1.out_layers.0.bias", "output_blocks.11.1.out_layers.3.weight", "output_blocks.11.1.out_layers.3.bias", "output_blocks.15.1.in_layers.0.weight", "output_blocks.15.1.in_layers.0.bias", "output_blocks.15.1.in_layers.2.weight", "output_blocks.15.1.in_layers.2.bias", "output_blocks.15.1.emb_layers.1.weight", "output_blocks.15.1.emb_layers.1.bias", "output_blocks.15.1.out_layers.0.weight", "output_blocks.15.1.out_layers.0.bias", "output_blocks.15.1.out_layers.3.weight", "output_blocks.15.1.out_layers.3.bias", "output_blocks.19.1.in_layers.0.weight", "output_blocks.19.1.in_layers.0.bias", "output_blocks.19.1.in_layers.2.weight", "output_blocks.19.1.in_layers.2.bias", "output_blocks.19.1.emb_layers.1.weight", "output_blocks.19.1.emb_layers.1.bias", "output_blocks.19.1.out_layers.0.weight", "output_blocks.19.1.out_layers.0.bias", "output_blocks.19.1.out_layers.3.weight", "output_blocks.19.1.out_layers.3.bias".

Unexpected key(s) in state_dict: "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.16.0.op.weight", "input_blocks.16.0.op.bias", "input_blocks.20.0.op.weight", "input_blocks.20.0.op.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.7.2.conv.weight", "output_blocks.7.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.15.1.conv.weight", "output_blocks.15.1.conv.bias", "output_blocks.19.1.conv.weight", "output_blocks.19.1.conv.bias".

gracezhao1997 commented 1 year ago

The reported error shows that the state_dict of the pretrained model and the state_dict of the model you are loading it into are different. Can you show the complete error, i.e. what the state_dict of the pretrained model looks like? You can also print torch.load(args.ckpt) (the state_dict of the pretrained model) and model.state_dict() (the state_dict of the model) to check the difference.
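For reference, a minimal sketch of that comparison, assuming `states` is the checkpoint loaded with torch.load(args.ckpt) and `model` is the UNetModel built from the inference config; both are taken from the surrounding EGSDE code (runners/egsde.py), not redefined here:

```python
# Diff the checkpoint keys against the keys the current model expects.
# `states` and `model` come from runners/egsde.py, just before
# model.load_state_dict(states) is called.
missing = sorted(set(model.state_dict()) - set(states))
unexpected = sorted(set(states) - set(model.state_dict()))

print(f"{len(missing)} missing keys, e.g. {missing[:5]}")
print(f"{len(unexpected)} unexpected keys, e.g. {unexpected[:5]}")
```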