ML-GSAI / EGSDE

Official implementation for "EGSDE: Unpaired Image-to-Image Translation via Energy-Guided Stochastic Differential Equations" (NIPS 2022)

Hello, I am trying to train the model from scratch without loading the pre-trained model. How should I do it? #1

Closed: LinjieFu-U closed this issue 1 year ago

LinjieFu-U commented 1 year ago

Hello, I am trying to train the model from scratch without loading the pre-trained model. What should I do?

gracezhao1997 commented 1 year ago

Hi @flj19951219, to train the domain-specific extractor, run run_train_dse.py, where data_path is the path to the source and target images and num class is the number of domains. By default, the code uses the pre-trained classifier from guided diffusion as the initial weights. If you set "pretrained" to False in the create_argparser function in run_train_dse.py, it will train the classifier from scratch instead. In that case, you may need to increase the number of training iterations.
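For readers unfamiliar with the dict-of-defaults argparse style used in guided-diffusion-based scripts, here is a minimal sketch of what a create_argparser with a "pretrained" switch might look like. Only data_path and pretrained come from this thread; num_class, iterations, and all default values are hypothetical and are not copied from run_train_dse.py. The str2bool helper matters because argparse's type=bool would turn the string "False" into True.

```python
import argparse


def str2bool(v):
    """Parse 'True'/'False'-style strings (plain type=bool treats any non-empty string as True)."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def create_argparser():
    # Hypothetical defaults; only data_path and pretrained are named in the thread.
    defaults = dict(
        data_path="path/to/source_and_target_images",
        num_class=2,        # hypothetical name: number of domains
        pretrained=True,    # True: initialise from the guided-diffusion classifier
        iterations=10000,   # hypothetical; consider raising it when pretrained=False
    )
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        # Booleans need str2bool; other types reuse the default's own type
        arg_type = str2bool if isinstance(value, bool) else type(value)
        parser.add_argument(f"--{key}", default=value, type=arg_type)
    return parser
```

With this pattern you can either change the default to pretrained=False in the source or, assuming the script exposes the flag this way, pass `--pretrained False --iterations 50000` on the command line.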

LinjieFu-U commented 1 year ago

Thank you for your reply. I tried training on my own data with run_train_dse.py and got two models, opt.pt and model.pt. When loading them in run_EGSDE.py for inference, I set ckpt in args.py to the path of opt.pt and dsepath to the path of model.pt, and then I got a parameter-mismatch error:

RuntimeError: Error(s) in loading state_dict for UNetModel: Missing key(s) in state_dict: "time_embed.0.weight", "time_embed.0.bias", "time_embed.2.weight", "time_embed.2.bias", "input_blocks.0.0.weight", "input_blocks.0.0.bias", "input_blocks.1.0.in_layers.0.weight", "input_blocks.1.0.in_layers.0.bias", "input_blocks.1.0.in_layers.2.weight", "input_blocks.1.0.in_layers.2.bias", "input_blocks.1.0.emb_layers.1.weight", "input_blocks.1.0.emb_layers.1.bias", "input_blocks.1.0.out_layers.0.weight", "input_blocks.1.0.out_layers.0.bias", "input_blocks.1.0.out_layers.3.weight", "input_blocks.1.0.out_layers.3.bias", "input_blocks.2.0.in_layers.0.weight", "input_blocks.2.0.in_layers.0.bias", "input_blocks.2.0.in_layers.2.weight", "input_blocks.2.0.in_layers.2.bias", "input_blocks.2.0.emb_layers.1.weight", "input_blocks.2.0.emb_layers.1.bias", "input_blocks.2.0.out_layers.0.weight", "input_blocks.2.0.out_layers.0.bias", "input_blocks.2.0.out_layers.3.weight", "input_blocks.2.0.out_layers.3.bias", "input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias", "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", 
"input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.5.0.in_layers.0.weight", "input_blocks.5.0.in_layers.0.bias", "input_blocks.5.0.in_layers.2.weight", "input_blocks.5.0.in_layers.2.bias", "input_blocks.5.0.emb_layers.1.weight", "input_blocks.5.0.emb_layers.1.bias", "input_blocks.5.0.out_layers.0.weight", "input_blocks.5.0.out_layers.0.bias", "input_blocks.5.0.out_layers.3.weight", "input_blocks.5.0.out_layers.3.bias", "input_blocks.5.0.skip_connection.weight", "input_blocks.5.0.skip_connection.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.7.0.in_layers.0.weight", "input_blocks.7.0.in_layers.0.bias", "input_blocks.7.0.in_layers.2.weight", "input_blocks.7.0.in_layers.2.bias", "input_blocks.7.0.emb_layers.1.weight", "input_blocks.7.0.emb_layers.1.bias", "input_blocks.7.0.out_layers.0.weight", "input_blocks.7.0.out_layers.0.bias", "input_blocks.7.0.out_layers.3.weight", "input_blocks.7.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", 
"input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.9.0.skip_connection.weight", "input_blocks.9.0.skip_connection.bias", "input_blocks.9.1.norm.weight", "input_blocks.9.1.norm.bias", "input_blocks.9.1.qkv.weight", "input_blocks.9.1.qkv.bias", "input_blocks.9.1.proj_out.weight", "input_blocks.9.1.proj_out.bias", "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias", "input_blocks.11.0.in_layers.0.weight", "input_blocks.11.0.in_layers.0.bias", "input_blocks.11.0.in_layers.2.weight", "input_blocks.11.0.in_layers.2.bias", "input_blocks.11.0.emb_layers.1.weight", "input_blocks.11.0.emb_layers.1.bias", "input_blocks.11.0.out_layers.0.weight", "input_blocks.11.0.out_layers.0.bias", "input_blocks.11.0.out_layers.3.weight", "input_blocks.11.0.out_layers.3.bias", "middle_block.0.in_layers.0.weight", "middle_block.0.in_layers.0.bias", "middle_block.0.in_layers.2.weight", "middle_block.0.in_layers.2.bias", "middle_block.0.emb_layers.1.weight", "middle_block.0.emb_layers.1.bias", "middle_block.0.out_layers.0.weight", "middle_block.0.out_layers.0.bias", "middle_block.0.out_layers.3.weight", "middle_block.0.out_layers.3.bias", "middle_block.1.norm.weight", "middle_block.1.norm.bias", "middle_block.1.qkv.weight", "middle_block.1.qkv.bias", "middle_block.1.proj_out.weight", "middle_block.1.proj_out.bias", "middle_block.2.in_layers.0.weight", "middle_block.2.in_layers.0.bias", "middle_block.2.in_layers.2.weight", "middle_block.2.in_layers.2.bias", "middle_block.2.emb_layers.1.weight", "middle_block.2.emb_layers.1.bias", "middle_block.2.out_layers.0.weight", 
"middle_block.2.out_layers.0.bias", "middle_block.2.out_layers.3.weight", "middle_block.2.out_layers.3.bias", "output_blocks.0.0.in_layers.0.weight", "output_blocks.0.0.in_layers.0.bias", "output_blocks.0.0.in_layers.2.weight", "output_blocks.0.0.in_layers.2.bias", "output_blocks.0.0.emb_layers.1.weight", "output_blocks.0.0.emb_layers.1.bias", "output_blocks.0.0.out_layers.0.weight", "output_blocks.0.0.out_layers.0.bias", "output_blocks.0.0.out_layers.3.weight", "output_blocks.0.0.out_layers.3.bias", "output_blocks.0.0.skip_connection.weight", "output_blocks.0.0.skip_connection.bias", "output_blocks.1.0.in_layers.0.weight", "output_blocks.1.0.in_layers.0.bias", "output_blocks.1.0.in_layers.2.weight", "output_blocks.1.0.in_layers.2.bias", "output_blocks.1.0.emb_layers.1.weight", "output_blocks.1.0.emb_layers.1.bias", "output_blocks.1.0.out_layers.0.weight", "output_blocks.1.0.out_layers.0.bias", "output_blocks.1.0.out_layers.3.weight", "output_blocks.1.0.out_layers.3.bias", "output_blocks.1.0.skip_connection.weight", "output_blocks.1.0.skip_connection.bias", "output_blocks.1.1.in_layers.0.weight", "output_blocks.1.1.in_layers.0.bias", "output_blocks.1.1.in_layers.2.weight", "output_blocks.1.1.in_layers.2.bias", "output_blocks.1.1.emb_layers.1.weight", "output_blocks.1.1.emb_layers.1.bias", "output_blocks.1.1.out_layers.0.weight", "output_blocks.1.1.out_layers.0.bias", "output_blocks.1.1.out_layers.3.weight", "output_blocks.1.1.out_layers.3.bias", "output_blocks.2.0.in_layers.0.weight", "output_blocks.2.0.in_layers.0.bias", "output_blocks.2.0.in_layers.2.weight", "output_blocks.2.0.in_layers.2.bias", "output_blocks.2.0.emb_layers.1.weight", "output_blocks.2.0.emb_layers.1.bias", "output_blocks.2.0.out_layers.0.weight", "output_blocks.2.0.out_layers.0.bias", "output_blocks.2.0.out_layers.3.weight", "output_blocks.2.0.out_layers.3.bias", "output_blocks.2.0.skip_connection.weight", "output_blocks.2.0.skip_connection.bias", "output_blocks.2.1.norm.weight", 
"output_blocks.2.1.norm.bias", "output_blocks.2.1.qkv.weight", "output_blocks.2.1.qkv.bias", "output_blocks.2.1.proj_out.weight", "output_blocks.2.1.proj_out.bias", "output_blocks.3.0.in_layers.0.weight", "output_blocks.3.0.in_layers.0.bias", "output_blocks.3.0.in_layers.2.weight", "output_blocks.3.0.in_layers.2.bias", "output_blocks.3.0.emb_layers.1.weight", "output_blocks.3.0.emb_layers.1.bias", "output_blocks.3.0.out_layers.0.weight", "output_blocks.3.0.out_layers.0.bias", "output_blocks.3.0.out_layers.3.weight", "output_blocks.3.0.out_layers.3.bias", "output_blocks.3.0.skip_connection.weight", "output_blocks.3.0.skip_connection.bias", "output_blocks.3.1.norm.weight", "output_blocks.3.1.norm.bias", "output_blocks.3.1.qkv.weight", "output_blocks.3.1.qkv.bias", "output_blocks.3.1.proj_out.weight", "output_blocks.3.1.proj_out.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.4.0.in_layers.0.weight", "output_blocks.4.0.in_layers.0.bias", "output_blocks.4.0.in_layers.2.weight", "output_blocks.4.0.in_layers.2.bias", "output_blocks.4.0.emb_layers.1.weight", "output_blocks.4.0.emb_layers.1.bias", "output_blocks.4.0.out_layers.0.weight", "output_blocks.4.0.out_layers.0.bias", "output_blocks.4.0.out_layers.3.weight", "output_blocks.4.0.out_layers.3.bias", "output_blocks.4.0.skip_connection.weight", "output_blocks.4.0.skip_connection.bias", "output_blocks.5.0.in_layers.0.weight", "output_blocks.5.0.in_layers.0.bias", "output_blocks.5.0.in_layers.2.weight", "output_blocks.5.0.in_layers.2.bias", "output_blocks.5.0.emb_layers.1.weight", "output_blocks.5.0.emb_layers.1.bias", "output_blocks.5.0.out_layers.0.weight", 
"output_blocks.5.0.out_layers.0.bias", "output_blocks.5.0.out_layers.3.weight", "output_blocks.5.0.out_layers.3.bias", "output_blocks.5.0.skip_connection.weight", "output_blocks.5.0.skip_connection.bias", "output_blocks.5.1.in_layers.0.weight", "output_blocks.5.1.in_layers.0.bias", "output_blocks.5.1.in_layers.2.weight", "output_blocks.5.1.in_layers.2.bias", "output_blocks.5.1.emb_layers.1.weight", "output_blocks.5.1.emb_layers.1.bias", "output_blocks.5.1.out_layers.0.weight", "output_blocks.5.1.out_layers.0.bias", "output_blocks.5.1.out_layers.3.weight", "output_blocks.5.1.out_layers.3.bias", "output_blocks.6.0.in_layers.0.weight", "output_blocks.6.0.in_layers.0.bias", "output_blocks.6.0.in_layers.2.weight", "output_blocks.6.0.in_layers.2.bias", "output_blocks.6.0.emb_layers.1.weight", "output_blocks.6.0.emb_layers.1.bias", "output_blocks.6.0.out_layers.0.weight", "output_blocks.6.0.out_layers.0.bias", "output_blocks.6.0.out_layers.3.weight", "output_blocks.6.0.out_layers.3.bias", "output_blocks.6.0.skip_connection.weight", "output_blocks.6.0.skip_connection.bias", "output_blocks.7.0.in_layers.0.weight", "output_blocks.7.0.in_layers.0.bias", "output_blocks.7.0.in_layers.2.weight", "output_blocks.7.0.in_layers.2.bias", "output_blocks.7.0.emb_layers.1.weight", "output_blocks.7.0.emb_layers.1.bias", "output_blocks.7.0.out_layers.0.weight", "output_blocks.7.0.out_layers.0.bias", "output_blocks.7.0.out_layers.3.weight", "output_blocks.7.0.out_layers.3.bias", "output_blocks.7.0.skip_connection.weight", "output_blocks.7.0.skip_connection.bias", "output_blocks.7.1.in_layers.0.weight", "output_blocks.7.1.in_layers.0.bias", "output_blocks.7.1.in_layers.2.weight", "output_blocks.7.1.in_layers.2.bias", "output_blocks.7.1.emb_layers.1.weight", "output_blocks.7.1.emb_layers.1.bias", "output_blocks.7.1.out_layers.0.weight", "output_blocks.7.1.out_layers.0.bias", "output_blocks.7.1.out_layers.3.weight", "output_blocks.7.1.out_layers.3.bias", 
"output_blocks.8.0.in_layers.0.weight", "output_blocks.8.0.in_layers.0.bias", "output_blocks.8.0.in_layers.2.weight", "output_blocks.8.0.in_layers.2.bias", "output_blocks.8.0.emb_layers.1.weight", "output_blocks.8.0.emb_layers.1.bias", "output_blocks.8.0.out_layers.0.weight", "output_blocks.8.0.out_layers.0.bias", "output_blocks.8.0.out_layers.3.weight", "output_blocks.8.0.out_layers.3.bias", "output_blocks.8.0.skip_connection.weight", "output_blocks.8.0.skip_connection.bias", "output_blocks.9.0.in_layers.0.weight", "output_blocks.9.0.in_layers.0.bias", "output_blocks.9.0.in_layers.2.weight", "output_blocks.9.0.in_layers.2.bias", "output_blocks.9.0.emb_layers.1.weight", "output_blocks.9.0.emb_layers.1.bias", "output_blocks.9.0.out_layers.0.weight", "output_blocks.9.0.out_layers.0.bias", "output_blocks.9.0.out_layers.3.weight", "output_blocks.9.0.out_layers.3.bias", "output_blocks.9.0.skip_connection.weight", "output_blocks.9.0.skip_connection.bias", "output_blocks.9.1.in_layers.0.weight", "output_blocks.9.1.in_layers.0.bias", "output_blocks.9.1.in_layers.2.weight", "output_blocks.9.1.in_layers.2.bias", "output_blocks.9.1.emb_layers.1.weight", "output_blocks.9.1.emb_layers.1.bias", "output_blocks.9.1.out_layers.0.weight", "output_blocks.9.1.out_layers.0.bias", "output_blocks.9.1.out_layers.3.weight", "output_blocks.9.1.out_layers.3.bias", "output_blocks.10.0.in_layers.0.weight", "output_blocks.10.0.in_layers.0.bias", "output_blocks.10.0.in_layers.2.weight", "output_blocks.10.0.in_layers.2.bias", "output_blocks.10.0.emb_layers.1.weight", "output_blocks.10.0.emb_layers.1.bias", "output_blocks.10.0.out_layers.0.weight", "output_blocks.10.0.out_layers.0.bias", "output_blocks.10.0.out_layers.3.weight", "output_blocks.10.0.out_layers.3.bias", "output_blocks.10.0.skip_connection.weight", "output_blocks.10.0.skip_connection.bias", "output_blocks.11.0.in_layers.0.weight", "output_blocks.11.0.in_layers.0.bias", "output_blocks.11.0.in_layers.2.weight", 
"output_blocks.11.0.in_layers.2.bias", "output_blocks.11.0.emb_layers.1.weight", "output_blocks.11.0.emb_layers.1.bias", "output_blocks.11.0.out_layers.0.weight", "output_blocks.11.0.out_layers.0.bias", "output_blocks.11.0.out_layers.3.weight", "output_blocks.11.0.out_layers.3.bias", "output_blocks.11.0.skip_connection.weight", "output_blocks.11.0.skip_connection.bias", "out.0.weight", "out.0.bias", "out.2.weight", "out.2.bias". Unexpected key(s) in state_dict: "state", "param_groups".

gracezhao1997 commented 1 year ago

run_train_dse.py is used to train the domain-specific extractor: the saved model.pt is the trained classifier and opt.pt is the optimizer state. EGSDE needs both a pretrained diffusion model and a domain-specific extractor. When running run_EGSDE.py, you should set ckpt in args.py to the path of the diffusion model rather than the optimizer file opt.pt. In other words, you need to train a diffusion model on your own target-domain data; code for training diffusion models is available at https://github.com/openai/guided-diffusion or https://github.com/ermongroup/ddim.
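The "Unexpected key(s) in state_dict: "state", "param_groups"" part of the error above is the signature of a PyTorch optimizer state dict, which is why pointing ckpt at opt.pt fails. A minimal heuristic sketch for checking a file before passing it to run_EGSDE.py follows. describe_checkpoint is a hypothetical helper, and the key-prefix checks assume guided-diffusion naming (the UNetModel keys in the log above) plus the assumption that the extractor is an encoder-only classifier without output_blocks; checkpoints from other codebases such as ermongroup/ddim use different key names and would come back as "unknown".

```python
from collections.abc import Mapping


def describe_checkpoint(state):
    """Guess what a loaded .pt object holds from its top-level keys (heuristic only)."""
    if not isinstance(state, Mapping):
        return "unknown"
    keys = set(state)
    # torch.optim.Optimizer.state_dict() returns a dict with these two top-level keys
    if {"state", "param_groups"} <= keys:
        return "optimizer state"
    has_in = any(k.startswith("input_blocks.") for k in keys)
    has_out = any(k.startswith("output_blocks.") for k in keys)
    if has_in and has_out:
        # Encoder and decoder halves present: looks like a guided-diffusion UNetModel
        return "diffusion UNet weights"
    if has_in:
        # Encoder half only: assumed to be a guided-diffusion-style classifier
        return "encoder-only classifier weights"
    return "unknown"
```

In practice you would call it on `torch.load(path, map_location="cpu")`; for the opt.pt file in this thread it should report "optimizer state", which means the file belongs to neither ckpt nor dsepath.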

xie-qiang commented 1 year ago

Hello, I want to ask what the optimizer file opt.pt is used for. Thank you!

gracezhao1997 commented 1 year ago

The optimizer file opt.pt is only used to resume training, so that you don't have to start from scratch if the program is interrupted.
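To illustrate why the optimizer state is saved separately, here is a minimal resume sketch. It assumes objects with PyTorch-style state_dict()/load_state_dict() methods; the step-numbered file names (model000123.pt, opt000123.pt) and both helper functions are hypothetical, not taken from run_train_dse.py. It uses pickle only to stay dependency-free; real PyTorch code would normally use torch.save and torch.load instead.

```python
import glob
import os
import pickle
import re


def save_checkpoint(model, optimizer, step, ckpt_dir):
    """Write model weights and optimizer state as two separate files for this step."""
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, f"model{step:06d}.pt"), "wb") as f:
        pickle.dump(model.state_dict(), f)
    with open(os.path.join(ckpt_dir, f"opt{step:06d}.pt"), "wb") as f:
        pickle.dump(optimizer.state_dict(), f)


def resume_latest(model, optimizer, ckpt_dir):
    """Restore the newest model/optimizer pair; return the step to continue from (0 if none)."""
    steps = []
    for path in glob.glob(os.path.join(ckpt_dir, "model*.pt")):
        match = re.fullmatch(r"model(\d+)\.pt", os.path.basename(path))
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return 0  # nothing saved yet: train from scratch
    step = max(steps)
    with open(os.path.join(ckpt_dir, f"model{step:06d}.pt"), "rb") as f:
        model.load_state_dict(pickle.load(f))
    opt_path = os.path.join(ckpt_dir, f"opt{step:06d}.pt")
    if os.path.exists(opt_path):
        # Restoring optimizer state (e.g. Adam moment estimates) lets training continue smoothly
        with open(opt_path, "rb") as f:
            optimizer.load_state_dict(pickle.load(f))
    return step
```

Only the model file is needed for inference; the optimizer file matters only if you intend to continue training.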

xie-qiang commented 1 year ago

I have another question. I see that when running run_train_dse.py, I need a pretrained_model. If I want to run run_train_dse.py on my custom dataset for training, can I use 256x256_classifier.pt as pretrained_model or do I need to retrain a classifier myself? Thank you!

gracezhao1997 commented 1 year ago

Since 256x256_classifier.pt was trained on the large ImageNet dataset, I expect that using it directly for image translation is also helpful, and the more similar your dataset is to ImageNet, the better the results. In my experiments, using 256x256_classifier.pt as the initial weights and fine-tuning it on our own data achieved the best results, so fine-tuning it on your dataset should improve performance further.
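One practical detail when fine-tuning from an ImageNet classifier: if your number of domains differs from ImageNet's 1000 classes, the final layer's shape will not match. Whether run_train_dse.py already handles this is not stated in the thread. A common approach, sketched here with a hypothetical helper, is to copy only tensors whose name and shape both match and leave the rest freshly initialised. It works on any mapping of name to tensor-like objects with a .shape attribute, such as PyTorch state dicts.

```python
def compatible_subset(pretrained, target):
    """Split pretrained weights into those matching the target model (by name and shape) and the rest."""
    kept, skipped = {}, []
    for name, value in pretrained.items():
        ref = target.get(name)
        if ref is not None and tuple(ref.shape) == tuple(value.shape):
            kept[name] = value
        else:
            # Missing in the target, or a different shape (e.g. a 1000-way head vs. a 2-way head)
            skipped.append(name)
    return kept, skipped
```

Usage with PyTorch would look like `kept, skipped = compatible_subset(torch.load("256x256_classifier.pt", map_location="cpu"), model.state_dict())` followed by `model.load_state_dict(kept, strict=False)`, so the skipped layers keep their random initialisation and are learned during fine-tuning.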

mmitchef commented 2 months ago

@xie-qiang @gracezhao1997 @LinjieFu-U @zhenxuan00 @aykborstelmann Hi, I want to train EGSDE on my own dataset and I am a little confused about the steps. Could you please confirm whether the following are the steps I need to follow to run EGSDE:

  1. Train a DDPM model only on my target-domain images.
  2. Run run_train_dse.py on both domains (source and target) to get domain-specific features.
  3. Run run_EGSDE.py with ckpt in args.py set to the path of the pre-trained diffusion model from step 1.

Thank you