tencent-ailab / IP-Adapter

The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with an image prompt.
Apache License 2.0

Training of ip_adapter_sdxl with ip_adapter_plus #323

Open Bilal143260 opened 8 months ago

Bilal143260 commented 8 months ago

Thank you for this amazing project. I want to train the SDXL IP-Adapter, but using the plus variant. What modifications do I have to make in tutorial_train_sdxl.py? Is it just the number of tokens to query from the CLIP image encoding?

regards

xiaohu2015 commented 8 months ago

you can follow https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train_plus.py
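
For context, the change is not only the token count: the plus variant also feeds the projection model the penultimate hidden states of the CLIP image encoder instead of the pooled image embedding. A minimal sketch of that difference, assuming a transformers CLIPVisionModelWithProjection encoder (the h94/IP-Adapter path below is an assumption, taken from the published image encoder):

import torch
from transformers import CLIPVisionModelWithProjection

# ViT-H image encoder used by the vit-h checkpoints (assumed path)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"
)
clip_images = torch.randn(1, 3, 224, 224)  # dummy preprocessed batch

with torch.no_grad():
    # base IP-Adapter: pooled, projected embedding -> (B, projection_dim)
    pooled = image_encoder(clip_images).image_embeds
    # plus variant: penultimate hidden states -> (B, seq_len, hidden_size)
    tokens = image_encoder(clip_images, output_hidden_states=True).hidden_states[-2]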

reesekneeland commented 3 months ago

Hello @xiaohu2015! Can you provide more clarity on how to set up the ip-adapter-plus_sdxl_vit-h training configuration? The tutorial_train_plus.py file doesn't seem to be compatible with the requirements of the SDXL models (it handles only one text encoder, etc.).

I'm specifically interested in fine-tuning the ip-adapter-plus_sdxl_vit-h.bin checkpoint and could use some guidance on setting this up. I tried implementing a new training file that combines plus and SDXL, but ran into many incorrect checkpoint dimensions when trying to load the model weights.
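
One way to track down such dimension mismatches is to compare the tensor shapes stored in the checkpoint against a freshly built module. A minimal sketch, assuming the standard IP-Adapter checkpoint layout ("image_proj" and "ip_adapter" keys) and that image_proj_model is the module you just constructed; diff_shapes is a hypothetical helper:

import torch

# hypothetical helper: compare tensor shapes in the checkpoint with a
# freshly built module to locate the mismatched layers
def diff_shapes(ckpt_path, module, key="image_proj"):
    ckpt = torch.load(ckpt_path, map_location="cpu")[key]
    own = module.state_dict()
    for name, tensor in ckpt.items():
        if name not in own:
            print(f"missing in model: {name} {tuple(tensor.shape)}")
        elif own[name].shape != tensor.shape:
            print(f"mismatch: {name} ckpt={tuple(tensor.shape)} model={tuple(own[name].shape)}")

diff_shapes("ip-adapter-plus_sdxl_vit-h.bin", image_proj_model)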

fkjkey commented 1 day ago

@reesekneeland In tutorial_train_sdxl.py, change

# image_proj_model = ImageProjModel(
#     cross_attention_dim=unet.config.cross_attention_dim,
#     clip_embeddings_dim=image_encoder.config.projection_dim,
#     clip_extra_context_tokens=num_tokens,
# )

To

image_proj_model = Resampler(
    dim=1280,                                        # internal width of the Resampler
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=num_tokens,                          # 16 for the plus checkpoints
    embedding_dim=image_encoder.config.hidden_size,  # 1280 for ViT-H
    output_dim=2048,                                 # SDXL UNet cross-attention dim
    ff_mult=4,
)
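
The Resampler class ships with the repo; assuming the same layout as tutorial_train_plus.py, the matching import would be:

from ip_adapter.resampler import Resampler

Note that output_dim must match the SDXL UNet's cross-attention dimension (2048) and embedding_dim the ViT-H hidden size (1280); these are the values that have to line up with the pretrained checkpoint to avoid the dimension errors above.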

as in https://github.com/tencent-ailab/IP-Adapter/issues/170, and change

# with torch.no_grad():
#     image_embeds = image_encoder(
#         batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
# image_embeds_ = []
# for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
#     if drop_image_embed == 1:
#         image_embeds_.append(torch.zeros_like(image_embed))
#     else:
#         image_embeds_.append(image_embed)
# image_embeds = torch.stack(image_embeds_)

To

clip_images = []
# classifier-free guidance dropout: zero out the input image itself
# (rather than the pooled embedding) before encoding
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
    if drop_image_embed == 1:
        clip_images.append(torch.zeros_like(clip_image))
    else:
        clip_images.append(clip_image)
clip_images = torch.stack(clip_images, dim=0)
with torch.no_grad():
    # the plus variant conditions on the penultimate hidden states,
    # not the pooled image embedding
    image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype),
                                 output_hidden_states=True).hidden_states[-2]

Then I can load ip-adapter-plus_sdxl_vit-h.bin and train.
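
For anyone following along: with the two changes above, the pretrained weights can be loaded before training. A minimal sketch, assuming the standard checkpoint layout and the variable names from tutorial_train_sdxl.py (image_proj_model, adapter_modules):

import torch

state_dict = torch.load("ip-adapter-plus_sdxl_vit-h.bin", map_location="cpu")

# "image_proj" holds the Resampler weights; "ip_adapter" holds the
# decoupled cross-attention (IPAttnProcessor) weights
image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)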