tencent-ailab / IP-Adapter

The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with an image prompt.
Apache License 2.0

SDXL Full pretrained model #289

Open gsrujana opened 5 months ago

gsrujana commented 5 months ago

Hi, I am trying to train SDXL with full CLIP embeddings and want to start from your pretrained weights. I modified "num_tokens" in tutorial_train_sdxl.py. Do you have a pretrained version of ipadapter-sdxl-full? I get a dimension mismatch error if I start from your SDXL IP-Adapter model. Thanks

xiaohu2015 commented 5 months ago

You should use this image projection model: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L49
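A minimal sketch of what such an MLP projection might look like, assuming PyTorch. The layer layout (Linear, activation, Linear, LayerNorm) is inferred from the `proj.0`, `proj.2` and `proj.3` keys reported in the loading error later in this thread. The GELU activation, the class name `MLPProjSketch`, and the dimensions (257 tokens of width 1280 from the penultimate hidden state of a ViT-H/14 image encoder, 2048 for SDXL cross-attention) are assumptions, so check them against the linked source before relying on them.

```python
import torch
from torch import nn


class MLPProjSketch(nn.Module):
    """Hypothetical re-creation of an MLP image projection.

    Indices 0, 2 and 3 of `proj` hold parameters (Linear, Linear, LayerNorm);
    index 1 is a parameter-free activation, so it produces no state_dict keys.
    """

    def __init__(self, cross_attention_dim: int = 2048, clip_embeddings_dim: int = 1280):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            nn.GELU(),  # assumed activation; it contributes no state_dict keys
            nn.Linear(clip_embeddings_dim, cross_attention_dim),
            nn.LayerNorm(cross_attention_dim),
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension, so every CLIP token is
        # projected independently: (B, N, C_clip) -> (B, N, C_cross).
        return self.proj(image_embeds)


if __name__ == "__main__":
    model = MLPProjSketch()
    # Assumed ViT-H/14 at 224x224: 16 * 16 patches + 1 class token = 257 tokens.
    hidden = torch.randn(2, 257, 1280)
    print(model(hidden).shape)  # torch.Size([2, 257, 2048])
    print(sorted(model.state_dict().keys()))
```

Because each CLIP token is projected on its own, the output keeps all 257 tokens, which is presumably why `num_tokens` has to change when switching to full embeddings.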

gsrujana commented 5 months ago
Traceback (most recent call last):
  File "tutorial_train_sdxl.py", line 487, in <module>
    main()    
  File "tutorial_train_sdxl.py", line 377, in main
    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
  File "tutorial_train_sdxl.py", line 161, in __init__
    self.load_from_checkpoint(ckpt_path)
  File "tutorial_train_sdxl.py", line 178, in load_from_checkpoint
    self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
  File "/opt/conda/envs/python310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MLPProjModel:
        Missing key(s) in state_dict: "proj.0.weight", "proj.0.bias", "proj.2.weight", "proj.2.bias", "proj.3.weight", "proj.3.bias". 
        Unexpected key(s) in state_dict: "norm.weight", "norm.bias", "proj.weight", "proj.bias". 
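
The failure comes from key matching, not from tensor shapes: `load_state_dict(strict=True)` requires the checkpoint's keys to match the module's keys exactly. The sketch below reproduces that comparison in plain Python, with the key names copied from the traceback; `diff_keys` is a hypothetical helper, not part of PyTorch.

```python
def diff_keys(expected, provided):
    """Mimic the key check behind load_state_dict(strict=True).

    Returns (missing, unexpected): keys the module expects but the checkpoint
    lacks, and keys the checkpoint holds but the module does not define.
    """
    expected, provided = set(expected), set(provided)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


# Keys copied from the traceback: what MLPProjModel expects ...
mlp_keys = ["proj.0.weight", "proj.0.bias", "proj.2.weight",
            "proj.2.bias", "proj.3.weight", "proj.3.bias"]
# ... and what the checkpoint's "image_proj" entry actually contains.
ckpt_keys = ["norm.weight", "norm.bias", "proj.weight", "proj.bias"]

missing, unexpected = diff_keys(mlp_keys, ckpt_keys)
print("Missing:", missing)
print("Unexpected:", unexpected)
```

No key overlaps, so none of the four tensors in the checkpoint's `image_proj` entry can be loaded into `MLPProjModel`; their names are consistent with a single Linear layer followed by a LayerNorm named `norm`, as in `ImageProjModel`.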
gsrujana commented 5 months ago

I am using your SDXL IP_adapter.bin as pretrained_ip_adapter_path. I changed ImageProjModel to MLPProjModel in tutorial_train_sdxl.py. Any suggestions, @xiaohu2015? Thanks!
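
One possible workaround, sketched under assumptions: keep the pretrained attention-adapter weights and train the new projection from scratch. `load_adapter_weights_only` is a hypothetical helper. It assumes the checkpoint is a dict with `"image_proj"` and `"ip_adapter"` entries (the first is visible in the traceback), that `adapter_modules` is the same `ModuleList` that `tutorial_train_sdxl.py` builds from the UNet's attention processors, and that both projection models output the same cross-attention width.

```python
import os
import tempfile

import torch
from torch import nn


def load_adapter_weights_only(adapter_modules: nn.Module, ckpt_path: str) -> None:
    """Hypothetical helper: load only the "ip_adapter" part of a checkpoint.

    The "image_proj" entry is deliberately ignored because its keys belong to
    a different projection architecture than MLPProjModel.
    """
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Adapter weights (e.g. to_k_ip / to_v_ip) are sized by the cross-attention
    # width, not by the projection architecture, so they can load strictly
    # as long as that width is unchanged (assumption).
    adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)


if __name__ == "__main__":
    # Toy stand-ins with hypothetical shapes, only to exercise the helper.
    pretrained = nn.ModuleList([nn.Linear(4, 3)])
    fresh = nn.ModuleList([nn.Linear(4, 3)])
    ckpt = {
        "image_proj": {"norm.weight": torch.ones(3)},  # never loaded
        "ip_adapter": pretrained.state_dict(),
    }
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "ip_adapter.bin")
        torch.save(ckpt, path)
        load_adapter_weights_only(fresh, path)
    print(torch.equal(fresh[0].weight, pretrained[0].weight))  # True
```

Whether the pretrained adapter weights remain a useful starting point once the projection is re-initialized is an open question. If the training script also asserts that the projection weights changed after loading, that check would have to be skipped as well.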

gsrujana commented 5 months ago

Also, how can I train using ip-adapter-plus_sdxl_vit-h.bin as the starting point? I am training on clothing data.