Open Bilal143260 opened 8 months ago
Hello @xiaohu2015! Can you provide more clarity on how to get the ip-adapter-plus_sdxl_vit-h training configuration set up? The tutorial_train_plus.py file doesn't seem to be compatible with the requirements of the sdxl models (only one text encoder, etc).
I'm specifically interested in fine-tuning the ip-adapter-plus_sdxl_vit-h.bin checkpoint and could use some guidance on setting this up. I tried implementing a new training file combining plus and sdxl, but ran into many incorrect checkpoint dimensions when trying to load the model weights.
@reesekneeland In tutorial_train_sdxl.py, change
# image_proj_model = ImageProjModel(
#     cross_attention_dim=unet.config.cross_attention_dim,
#     clip_embeddings_dim=image_encoder.config.projection_dim,
#     clip_extra_context_tokens=num_tokens,
# )
To
image_proj_model = Resampler(
    dim=1280,
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=num_tokens,
    embedding_dim=image_encoder.config.hidden_size,
    output_dim=2048,
    ff_mult=4,
)
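To see why those dimensions line up, here is a toy latent-query resampler in plain PyTorch: a set of learned query tokens attends over the CLIP patch sequence (hence embedding_dim = image_encoder.config.hidden_size, which is 1280 for ViT-H) and is projected to SDXL's UNet cross-attention width of 2048. This is only an illustrative sketch of the idea, not the repo's actual Resampler (which stacks several attention + feed-forward blocks):

```python
import torch
import torch.nn as nn

class ToyResampler(nn.Module):
    """Toy sketch: num_queries learned tokens attend over the CLIP patch
    sequence, then get projected to the UNet cross-attention width.
    Illustration only, not the repo's Resampler implementation."""
    def __init__(self, embedding_dim=1280, dim=1280, output_dim=2048,
                 num_queries=16, heads=20):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim))
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.proj_out = nn.Linear(dim, output_dim)

    def forward(self, x):                       # x: (B, seq_len, embedding_dim)
        x = self.proj_in(x)
        q = self.queries.unsqueeze(0).expand(x.size(0), -1, -1)
        out, _ = self.attn(q, x, x)             # (B, num_queries, dim)
        return self.proj_out(out)               # (B, num_queries, output_dim)

# ViT-H/14 penultimate hidden states: 256 patch tokens + 1 CLS, width 1280
hidden = torch.randn(2, 257, 1280)
tokens = ToyResampler()(hidden)
print(tokens.shape)  # torch.Size([2, 16, 2048])
```

The key difference from ImageProjModel: the plus variant consumes the whole token sequence instead of a single pooled embedding, which is why the adapter's projection weights have different shapes and the non-plus training script fails to load the plus checkpoint.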
as in https://github.com/tencent-ailab/IP-Adapter/issues/170, and change
# with torch.no_grad():
#     image_embeds = image_encoder(
#         batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
# image_embeds_ = []
# for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
#     if drop_image_embed == 1:
#         image_embeds_.append(torch.zeros_like(image_embed))
#     else:
#         image_embeds_.append(image_embed)
# image_embeds = torch.stack(image_embeds_)
To
clip_images = []
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
    if drop_image_embed == 1:
        clip_images.append(torch.zeros_like(clip_image))
    else:
        clip_images.append(clip_image)
clip_images = torch.stack(clip_images, dim=0)
with torch.no_grad():
    image_embeds = image_encoder(
        clip_images.to(accelerator.device, dtype=weight_dtype),
        output_hidden_states=True,
    ).hidden_states[-2]
Then I can load ip-adapter-plus_sdxl_vit-h.bin and train.
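Note the change in where the image-condition dropout happens: since the plus variant uses the full penultimate hidden-state sequence rather than a pooled embedding, a dropped sample is zeroed at the input image before encoding, not at the embedding afterwards. The per-sample logic can be checked on dummy tensors (the names here stand in for batch["clip_images"] and batch["drop_image_embeds"]):

```python
import torch

# Dummy stand-ins for a batch of preprocessed CLIP images and drop flags
batch_clip_images = torch.randn(4, 3, 224, 224)
drop_flags = [0, 1, 0, 1]  # 1 = drop the image condition for this sample

# Same selection logic as the training loop, written as a comprehension
clip_images = torch.stack([
    torch.zeros_like(img) if drop == 1 else img
    for img, drop in zip(batch_clip_images, drop_flags)
], dim=0)

# Dropped samples become all-zero inputs to the image encoder
print([bool(clip_images[i].abs().sum() == 0) for i in range(4)])
# [False, True, False, True]
```

This per-sample dropping is the usual trick for enabling classifier-free guidance on the image condition at inference time.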
Thank you for this amazing project. I want to train the SDXL IP-Adapter, but using the plus variant. What modifications do I have to make in tutorial_train_sdxl.py? Is it just the number of tokens to query from the CLIP image encoding?
regards