zer0int / CLIP-fine-tune

Fine-tuning code for CLIP models
MIT License

Train a CLIP that was saved into an SDXL model? #3

Closed: bash-j closed this issue 3 months ago

bash-j commented 4 months ago

Hello again! I used your GmP fine-tune script, and it worked really well. Thank you!

Is it possible to extract a CLIP from an SDXL model that has already been fine-tuned somewhat, and fine-tune it further with your script?

I tried writing something that builds the CLIP model the same way your "convert GmP back to weights" script does, but I'm getting a bunch of "unexpected key" errors.

This is the list of keys in the .safetensors file that ComfyUI saves:

['logit_scale'
 'text_model.embeddings.position_embedding.weight'
 'text_model.embeddings.token_embedding.weight'
 'text_model.encoder.layers.0.layer_norm1.bias'
 'text_model.encoder.layers.0.layer_norm1.weight'
 'text_model.encoder.layers.0.layer_norm2.bias'
 'text_model.encoder.layers.0.layer_norm2.weight'
 'text_model.encoder.layers.0.mlp.fc1.bias'
 'text_model.encoder.layers.0.mlp.fc1.weight'
 'text_model.encoder.layers.0.mlp.fc2.bias'
 'text_model.encoder.layers.0.mlp.fc2.weight'
 'text_model.encoder.layers.0.self_attn.k_proj.bias'
 'text_model.encoder.layers.0.self_attn.k_proj.weight'
 'text_model.encoder.layers.0.self_attn.out_proj.bias'
 'text_model.encoder.layers.0.self_attn.out_proj.weight'
 'text_model.encoder.layers.0.self_attn.q_proj.bias'
 'text_model.encoder.layers.0.self_attn.q_proj.weight'
 'text_model.encoder.layers.0.self_attn.v_proj.bias'
 'text_model.encoder.layers.0.self_attn.v_proj.weight'
 'text_model.encoder.layers.1.layer_norm1.bias'
 'text_model.encoder.layers.1.layer_norm1.weight'
 'text_model.encoder.layers.1.layer_norm2.bias'
 'text_model.encoder.layers.1.layer_norm2.weight'
 'text_model.encoder.layers.1.mlp.fc1.bias'
 'text_model.encoder.layers.1.mlp.fc1.weight'
 'text_model.encoder.layers.1.mlp.fc2.bias'
 'text_model.encoder.layers.1.mlp.fc2.weight'
 'text_model.encoder.layers.1.self_attn.k_proj.bias'
 'text_model.encoder.layers.1.self_attn.k_proj.weight'
 'text_model.encoder.layers.1.self_attn.out_proj.bias'
 'text_model.encoder.layers.1.self_attn.out_proj.weight'
 'text_model.encoder.layers.1.self_attn.q_proj.bias'
 'text_model.encoder.layers.1.self_attn.q_proj.weight'
 'text_model.encoder.layers.1.self_attn.v_proj.bias'
 'text_model.encoder.layers.1.self_attn.v_proj.weight'
 'text_model.encoder.layers.10.layer_norm1.bias'
 'text_model.encoder.layers.10.layer_norm1.weight'
 'text_model.encoder.layers.10.layer_norm2.bias'
 'text_model.encoder.layers.10.layer_norm2.weight'
 'text_model.encoder.layers.10.mlp.fc1.bias'
 'text_model.encoder.layers.10.mlp.fc1.weight'
 'text_model.encoder.layers.10.mlp.fc2.bias'
 'text_model.encoder.layers.10.mlp.fc2.weight'
 'text_model.encoder.layers.10.self_attn.k_proj.bias'
 'text_model.encoder.layers.10.self_attn.k_proj.weight'
 'text_model.encoder.layers.10.self_attn.out_proj.bias'
 'text_model.encoder.layers.10.self_attn.out_proj.weight'
 'text_model.encoder.layers.10.self_attn.q_proj.bias'
 'text_model.encoder.layers.10.self_attn.q_proj.weight'
 'text_model.encoder.layers.10.self_attn.v_proj.bias'
 'text_model.encoder.layers.10.self_attn.v_proj.weight'
 'text_model.encoder.layers.11.layer_norm1.bias'
 'text_model.encoder.layers.11.layer_norm1.weight'
 'text_model.encoder.layers.11.layer_norm2.bias'
 'text_model.encoder.layers.11.layer_norm2.weight'
 'text_model.encoder.layers.11.mlp.fc1.bias'
 'text_model.encoder.layers.11.mlp.fc1.weight'
 'text_model.encoder.layers.11.mlp.fc2.bias'
 'text_model.encoder.layers.11.mlp.fc2.weight'
 'text_model.encoder.layers.11.self_attn.k_proj.bias'
 'text_model.encoder.layers.11.self_attn.k_proj.weight'
 'text_model.encoder.layers.11.self_attn.out_proj.bias'
 'text_model.encoder.layers.11.self_attn.out_proj.weight'
 'text_model.encoder.layers.11.self_attn.q_proj.bias'
 'text_model.encoder.layers.11.self_attn.q_proj.weight'
 'text_model.encoder.layers.11.self_attn.v_proj.bias'
 'text_model.encoder.layers.11.self_attn.v_proj.weight'
 'text_model.encoder.layers.2.layer_norm1.bias'
 'text_model.encoder.layers.2.layer_norm1.weight'
 'text_model.encoder.layers.2.layer_norm2.bias'
 'text_model.encoder.layers.2.layer_norm2.weight'
 'text_model.encoder.layers.2.mlp.fc1.bias'
 'text_model.encoder.layers.2.mlp.fc1.weight'
 'text_model.encoder.layers.2.mlp.fc2.bias'
 'text_model.encoder.layers.2.mlp.fc2.weight'
 'text_model.encoder.layers.2.self_attn.k_proj.bias'
 'text_model.encoder.layers.2.self_attn.k_proj.weight'
 'text_model.encoder.layers.2.self_attn.out_proj.bias'
 'text_model.encoder.layers.2.self_attn.out_proj.weight'
 'text_model.encoder.layers.2.self_attn.q_proj.bias'
 'text_model.encoder.layers.2.self_attn.q_proj.weight'
 'text_model.encoder.layers.2.self_attn.v_proj.bias'
 'text_model.encoder.layers.2.self_attn.v_proj.weight'
 'text_model.encoder.layers.3.layer_norm1.bias'
 'text_model.encoder.layers.3.layer_norm1.weight'
 'text_model.encoder.layers.3.layer_norm2.bias'
 'text_model.encoder.layers.3.layer_norm2.weight'
 'text_model.encoder.layers.3.mlp.fc1.bias'
 'text_model.encoder.layers.3.mlp.fc1.weight'
 'text_model.encoder.layers.3.mlp.fc2.bias'
 'text_model.encoder.layers.3.mlp.fc2.weight'
 'text_model.encoder.layers.3.self_attn.k_proj.bias'
 'text_model.encoder.layers.3.self_attn.k_proj.weight'
 'text_model.encoder.layers.3.self_attn.out_proj.bias'
 'text_model.encoder.layers.3.self_attn.out_proj.weight'
 'text_model.encoder.layers.3.self_attn.q_proj.bias'
 'text_model.encoder.layers.3.self_attn.q_proj.weight'
 'text_model.encoder.layers.3.self_attn.v_proj.bias'
 'text_model.encoder.layers.3.self_attn.v_proj.weight'
 'text_model.encoder.layers.4.layer_norm1.bias'
 'text_model.encoder.layers.4.layer_norm1.weight'
 'text_model.encoder.layers.4.layer_norm2.bias'
 'text_model.encoder.layers.4.layer_norm2.weight'
 'text_model.encoder.layers.4.mlp.fc1.bias'
 'text_model.encoder.layers.4.mlp.fc1.weight'
 'text_model.encoder.layers.4.mlp.fc2.bias'
 'text_model.encoder.layers.4.mlp.fc2.weight'
 'text_model.encoder.layers.4.self_attn.k_proj.bias'
 'text_model.encoder.layers.4.self_attn.k_proj.weight'
 'text_model.encoder.layers.4.self_attn.out_proj.bias'
 'text_model.encoder.layers.4.self_attn.out_proj.weight'
 'text_model.encoder.layers.4.self_attn.q_proj.bias'
 'text_model.encoder.layers.4.self_attn.q_proj.weight'
 'text_model.encoder.layers.4.self_attn.v_proj.bias'
 'text_model.encoder.layers.4.self_attn.v_proj.weight'
 'text_model.encoder.layers.5.layer_norm1.bias'
 'text_model.encoder.layers.5.layer_norm1.weight'
 'text_model.encoder.layers.5.layer_norm2.bias'
 'text_model.encoder.layers.5.layer_norm2.weight'
 'text_model.encoder.layers.5.mlp.fc1.bias'
 'text_model.encoder.layers.5.mlp.fc1.weight'
 'text_model.encoder.layers.5.mlp.fc2.bias'
 'text_model.encoder.layers.5.mlp.fc2.weight'
 'text_model.encoder.layers.5.self_attn.k_proj.bias'
 'text_model.encoder.layers.5.self_attn.k_proj.weight'
 'text_model.encoder.layers.5.self_attn.out_proj.bias'
 'text_model.encoder.layers.5.self_attn.out_proj.weight'
 'text_model.encoder.layers.5.self_attn.q_proj.bias'
 'text_model.encoder.layers.5.self_attn.q_proj.weight'
 'text_model.encoder.layers.5.self_attn.v_proj.bias'
 'text_model.encoder.layers.5.self_attn.v_proj.weight'
 'text_model.encoder.layers.6.layer_norm1.bias'
 'text_model.encoder.layers.6.layer_norm1.weight'
 'text_model.encoder.layers.6.layer_norm2.bias'
 'text_model.encoder.layers.6.layer_norm2.weight'
 'text_model.encoder.layers.6.mlp.fc1.bias'
 'text_model.encoder.layers.6.mlp.fc1.weight'
 'text_model.encoder.layers.6.mlp.fc2.bias'
 'text_model.encoder.layers.6.mlp.fc2.weight'
 'text_model.encoder.layers.6.self_attn.k_proj.bias'
 'text_model.encoder.layers.6.self_attn.k_proj.weight'
 'text_model.encoder.layers.6.self_attn.out_proj.bias'
 'text_model.encoder.layers.6.self_attn.out_proj.weight'
 'text_model.encoder.layers.6.self_attn.q_proj.bias'
 'text_model.encoder.layers.6.self_attn.q_proj.weight'
 'text_model.encoder.layers.6.self_attn.v_proj.bias'
 'text_model.encoder.layers.6.self_attn.v_proj.weight'
 'text_model.encoder.layers.7.layer_norm1.bias'
 'text_model.encoder.layers.7.layer_norm1.weight'
 'text_model.encoder.layers.7.layer_norm2.bias'
 'text_model.encoder.layers.7.layer_norm2.weight'
 'text_model.encoder.layers.7.mlp.fc1.bias'
 'text_model.encoder.layers.7.mlp.fc1.weight'
 'text_model.encoder.layers.7.mlp.fc2.bias'
 'text_model.encoder.layers.7.mlp.fc2.weight'
 'text_model.encoder.layers.7.self_attn.k_proj.bias'
 'text_model.encoder.layers.7.self_attn.k_proj.weight'
 'text_model.encoder.layers.7.self_attn.out_proj.bias'
 'text_model.encoder.layers.7.self_attn.out_proj.weight'
 'text_model.encoder.layers.7.self_attn.q_proj.bias'
 'text_model.encoder.layers.7.self_attn.q_proj.weight'
 'text_model.encoder.layers.7.self_attn.v_proj.bias'
 'text_model.encoder.layers.7.self_attn.v_proj.weight'
 'text_model.encoder.layers.8.layer_norm1.bias'
 'text_model.encoder.layers.8.layer_norm1.weight'
 'text_model.encoder.layers.8.layer_norm2.bias'
 'text_model.encoder.layers.8.layer_norm2.weight'
 'text_model.encoder.layers.8.mlp.fc1.bias'
 'text_model.encoder.layers.8.mlp.fc1.weight'
 'text_model.encoder.layers.8.mlp.fc2.bias'
 'text_model.encoder.layers.8.mlp.fc2.weight'
 'text_model.encoder.layers.8.self_attn.k_proj.bias'
 'text_model.encoder.layers.8.self_attn.k_proj.weight'
 'text_model.encoder.layers.8.self_attn.out_proj.bias'
 'text_model.encoder.layers.8.self_attn.out_proj.weight'
 'text_model.encoder.layers.8.self_attn.q_proj.bias'
 'text_model.encoder.layers.8.self_attn.q_proj.weight'
 'text_model.encoder.layers.8.self_attn.v_proj.bias'
 'text_model.encoder.layers.8.self_attn.v_proj.weight'
 'text_model.encoder.layers.9.layer_norm1.bias'
 'text_model.encoder.layers.9.layer_norm1.weight'
 'text_model.encoder.layers.9.layer_norm2.bias'
 'text_model.encoder.layers.9.layer_norm2.weight'
 'text_model.encoder.layers.9.mlp.fc1.bias'
 'text_model.encoder.layers.9.mlp.fc1.weight'
 'text_model.encoder.layers.9.mlp.fc2.bias'
 'text_model.encoder.layers.9.mlp.fc2.weight'
 'text_model.encoder.layers.9.self_attn.k_proj.bias'
 'text_model.encoder.layers.9.self_attn.k_proj.weight'
 'text_model.encoder.layers.9.self_attn.out_proj.bias'
 'text_model.encoder.layers.9.self_attn.out_proj.weight'
 'text_model.encoder.layers.9.self_attn.q_proj.bias'
 'text_model.encoder.layers.9.self_attn.q_proj.weight'
 'text_model.encoder.layers.9.self_attn.v_proj.bias'
 'text_model.encoder.layers.9.self_attn.v_proj.weight'
 'text_model.final_layer_norm.bias'
 'text_model.final_layer_norm.weight'
 'text_projection.weight']
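For context on the GmP-to-weights conversion mentioned above: a minimal sketch, assuming (this is not taken from the repository) that GmP stores each linear layer as a per-output-row magnitude `r` and an unnormalized direction `theta`, so the ordinary weight is recovered row by row as `r * theta / ||theta||`. The function name `gmp_to_weight` is hypothetical, and plain Python lists stand in for tensors.

```python
import math


def gmp_to_weight(r, theta):
    """Rebuild an ordinary weight matrix from GmP-style parameters.

    ASSUMPTION (illustrative, not the repository's code): r[i] is the
    magnitude of output row i and theta[i] is that row's direction, so
    weight[i] = r[i] * theta[i] / ||theta[i]||.
    """
    if len(r) != len(theta):
        raise ValueError("r and theta must have the same number of rows")
    weight = []
    for r_i, row in zip(r, theta):
        norm = math.sqrt(sum(x * x for x in row))  # L2 norm of the direction row
        if norm == 0.0:
            raise ValueError("theta row has zero norm, direction is undefined")
        weight.append([r_i * x / norm for x in row])
    return weight
```

With PyTorch tensors, the same step would be `r * torch.nn.functional.normalize(theta, p=2, dim=1)`, again assuming `r` has shape `(out_features, 1)`.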
zer0int commented 4 months ago

Hi again!

Interesting! Yeah, I can confirm that the .safetensors object is indeed just the text transformer (which makes sense, as it is the SDXL text encoder). But we'd need CLIP's "better half", the vision transformer, back to fine-tune it...
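To check which parts of CLIP a checkpoint actually contains, you can group its keys by prefix. The helper below is a hypothetical sketch that works on key names only; the prefixes it looks for follow the Hugging Face `CLIPModel` naming visible in the key list above (`text_model.*`, `vision_model.*`, `text_projection.weight`, `visual_projection.weight`).

```python
import re


def _layer_indices(keys, prefix):
    """Collect the distinct encoder layer indices under a tower prefix."""
    pattern = re.compile(re.escape(prefix) + r"\.encoder\.layers\.(\d+)\.")
    indices = set()
    for key in keys:
        match = pattern.match(key)
        if match:
            indices.add(int(match.group(1)))
    return indices


def summarize_clip_keys(keys):
    """Report which CLIP parts a state dict contains (HF CLIPModel naming)."""
    keys = set(keys)
    return {
        "text_layers": len(_layer_indices(keys, "text_model")),
        "vision_layers": len(_layer_indices(keys, "vision_model")),
        "has_text_projection": "text_projection.weight" in keys,
        "has_visual_projection": "visual_projection.weight" in keys,
        "has_logit_scale": "logit_scale" in keys,
    }
```

The key names can be read without loading any tensors via `from safetensors import safe_open` and `with safe_open(path, framework="pt") as f: keys = list(f.keys())`. For the key list above, this reports 12 text layers, a text projection and a logit scale, but zero vision layers and no visual projection, i.e. only half of CLIP.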

Now, that means you'd have to plug the text transformer you've fine-tuned together with the vision transformer of a CLIP that was not fine-tuned. That is probably going to raise hell in the gradient norms, as CLIP would have to re-learn its entire text-image alignment in the projection space. You could, of course, freeze the text transformer and just train / adjust the vision side. Best without GmP, as that raises hell in the gradients even when the two modalities are aligned at the start.
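Freezing the text transformer and training only the vision side comes down to switching off `requires_grad` by parameter name. A minimal sketch: `freeze_by_prefix` is a hypothetical helper, the default prefixes assume Hugging Face `CLIPModel` names (as in the key list above), and the alternative preset assumes the attribute names of OpenAI's `clip` package model, whose vision tower lives under `visual.`.

```python
# Hugging Face CLIPModel names for the text side (see the key list above).
HF_TEXT_PREFIXES = ("text_model.", "text_projection")
# Text side of an OpenAI `clip` package model; the vision side is under "visual.".
OPENAI_TEXT_PREFIXES = (
    "transformer.",
    "token_embedding",
    "positional_embedding",
    "ln_final",
    "text_projection",
)


def freeze_by_prefix(named_params, frozen_prefixes=HF_TEXT_PREFIXES):
    """Freeze every parameter whose name starts with one of frozen_prefixes.

    named_params: iterable of (name, parameter) pairs, e.g. model.named_parameters().
    Returns (trainable_names, frozen_names).
    """
    prefixes = tuple(frozen_prefixes)
    trainable, frozen = [], []
    for name, param in named_params:
        is_frozen = name.startswith(prefixes)
        param.requires_grad = not is_frozen  # frozen params receive no gradient
        (frozen if is_frozen else trainable).append(name)
    return trainable, frozen
```

Afterwards, the optimizer would be built from the remaining parameters only, e.g. `torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=...)`. `logit_scale` is in neither prefix list, so it stays trainable; whether to freeze it as well is a judgment call.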

However, I am somewhat skeptical about this approach. I have never tried it, but puzzling together two non-matching (unaligned) transformers into a CLIP and then force-aligning one side - hmmm, sounds like trouble, unless you can do a batch_size of 2048. Especially as CLIP is already quite delicate to fine-tune standalone (as we've discussed before).
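The batch-size concern follows from CLIP's contrastive objective: every image in a batch is scored against every caption, so each positive pair is contrasted with batch_size - 1 negatives, and a small consumer-GPU batch offers far fewer negatives than the 32,768 used in the original CLIP paper. The pure-Python sketch below shows that symmetric cross-entropy over the similarity matrix; it illustrates the objective only (it is not this repository's training loop), and the fixed `logit_scale` of 100 is an illustrative value.

```python
import math


def _normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _mean_cross_entropy(logits):
    """Mean cross-entropy where the correct class of row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def clip_contrastive_loss(image_embs, text_embs, logit_scale=100.0):
    """Symmetric CLIP-style loss over one batch of matching (image, text) pairs.

    Each image is scored against every text in the batch, so every positive
    pair competes with len(batch) - 1 negatives in each direction.
    """
    images = [_normalize(v) for v in image_embs]
    texts = [_normalize(v) for v in text_embs]
    logits_per_image = [
        [logit_scale * sum(a * b for a, b in zip(img, txt)) for txt in texts]
        for img in images
    ]
    logits_per_text = [list(col) for col in zip(*logits_per_image)]  # transpose
    return (_mean_cross_entropy(logits_per_image) + _mean_cross_entropy(logits_per_text)) / 2
```

When one tower is unaligned (as in the swapped pairs of a mismatched text/vision combination), this loss starts very high, and with few negatives per step each gradient update sees only a small slice of the contrast CLIP was trained on.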

Probably the only way would be to fine-tune CLIP first (standalone), then put it into SDXL as "the text encoder", and train SDXL with a frozen TE, i.e. just re-aligning the U-Net to that new CLIP. But once you train the TE together with the U-Net, it becomes aligned to that generative AI system, and without a "better half" that is aligned to that TE, I assume it might be troublesome to fine-tune that CLIP standalone again.

However, this is just "intuition with a bit of reasoning"; I have never tried it. Maybe CLIP can just align a vision transformer to a text transformer trained elsewhere. But again, I am skeptical because of the gradients, the batch_size limit of consumer GPUs, and, alas, the missing regularization that typically comes with large batch sizes (with CLIP hating other regularization, such as gradient clipping).

I quickly asked GPT-4o for a combo-breaker script that puzzles the HF parts together (TE + ViT-L/14) and then converts the result to the original OpenAI model structure; so far, however, it still doesn't work. It has debug prints and all, though, so I'll share it - just in case you still want to try this (and maybe prove me wrong about my concerns with plugging an unknown-to-CLIP vision transformer into your fine-tuned CLIP TE!):

DELETED, FIXED.

Unfortunately, I have to join a lengthy video call in 15 minutes, so all I have right now is this not-quite-working-yet script. Hope this helps, though - much success (or maybe just good luck?! 🫣)! 🙃
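The shared script was later removed, but the central step of converting a Hugging Face CLIP text encoder back to the original OpenAI layout can be sketched as a renaming table. This is a hypothetical, name-only sketch written from the two public model definitions (Hugging Face `CLIPTextModel` and OpenAI's `clip` model), not the committed merge code. Two keys need more than a rename: the separate `q_proj`/`k_proj`/`v_proj` tensors have to be concatenated into OpenAI's fused `in_proj_weight`/`in_proj_bias`, and `text_projection.weight` has to be transposed, so the function returns those instructions instead of touching tensors.

```python
import re

# Per-layer module renames: HF CLIPTextModel -> OpenAI CLIP text transformer.
_LAYER_RENAMES = {
    "layer_norm1": "ln_1",
    "layer_norm2": "ln_2",
    "mlp.fc1": "mlp.c_fc",
    "mlp.fc2": "mlp.c_proj",
    "self_attn.out_proj": "attn.out_proj",
}
# OpenAI CLIP uses nn.MultiheadAttention, which packs q, k, v into one tensor.
_QKV_SLOTS = {"self_attn.q_proj": 0, "self_attn.k_proj": 1, "self_attn.v_proj": 2}
_FIXED = {
    "logit_scale": "logit_scale",
    "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
    "text_model.embeddings.position_embedding.weight": "positional_embedding",
    "text_model.final_layer_norm.weight": "ln_final.weight",
    "text_model.final_layer_norm.bias": "ln_final.bias",
}
_LAYER_KEY = re.compile(r"text_model\.encoder\.layers\.(\d+)\.(.+)\.(weight|bias)")


def hf_text_key_to_openai(key):
    """Map one HF text-encoder key to (openai_key, qkv_slot, transpose).

    qkv_slot is 0/1/2 if the tensor must be concatenated (q, k, v order,
    dim 0) into the fused in_proj tensor, else None. transpose is True if
    the tensor must be transposed before loading.
    """
    if key in _FIXED:
        return _FIXED[key], None, False
    if key == "text_projection.weight":
        # HF stores a Linear weight (out, in); OpenAI computes x @ text_projection.
        return "text_projection", None, True
    match = _LAYER_KEY.fullmatch(key)
    if match is None:
        raise KeyError(f"not a recognized text-encoder key: {key}")
    layer, module, kind = match.groups()
    prefix = f"transformer.resblocks.{layer}."
    if module in _QKV_SLOTS:
        return f"{prefix}attn.in_proj_{kind}", _QKV_SLOTS[module], False
    if module in _LAYER_RENAMES:
        return f"{prefix}{_LAYER_RENAMES[module]}.{kind}", None, False
    raise KeyError(f"no OpenAI equivalent known for: {key}")
```

Grouping the three slots of each layer and calling `torch.cat([q, k, v], dim=0)` would then yield the fused tensor.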

zer0int commented 3 months ago

I just committed the folder "merge-SDXL-TE-into-full-CLIP-model-object". It lets you puzzle CLIP back together into a full CLIP model that you can use with all of my scripts. I tried this with an old TE I had previously fine-tuned with kohya, along with the U-Net, and, as expected, its text-vision latent space / projection space is completely unaligned.
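Conceptually, putting CLIP back together means taking a complete base CLIP state dict (for SDXL's first text encoder, a CLIP ViT-L/14 such as `openai/clip-vit-large-patch14`) and overwriting only its text-side entries with the fine-tuned ones, leaving the vision tower untouched. The sketch below is hypothetical and is not the committed merge script; it works on plain dicts in Hugging Face naming and reports keys that did not line up, which is where "unexpected key" errors come from.

```python
def merge_text_into_clip(base_sd, text_sd, text_prefixes=("text_model.", "text_projection")):
    """Return (merged, unexpected, missing) for replacing a CLIP text tower.

    base_sd: full CLIP state dict (text + vision), HF CLIPModel naming.
    text_sd: fine-tuned text-encoder state dict, e.g. the file extracted from SDXL.
    unexpected: keys in text_sd that base_sd does not have (left out of merged).
    missing: text-side keys of base_sd that text_sd did not provide (kept from base).
    """
    merged = dict(base_sd)  # shallow copy; base_sd itself is not modified
    unexpected = sorted(k for k in text_sd if k not in base_sd)
    for key, value in text_sd.items():
        if key in base_sd:
            merged[key] = value  # shared keys such as logit_scale also come from text_sd
    prefixes = tuple(text_prefixes)
    missing = sorted(k for k in base_sd if k.startswith(prefixes) and k not in text_sd)
    return merged, unexpected, missing
```

With real tensors you would also compare shapes before copying and then call `model.load_state_dict(merged)`. Depending on the transformers version, a non-learned buffer may show up in `missing`; it is simply kept from the base model.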

In other words, the model is completely bonkers. It cannot see even a faint glimpse of "cat" anymore, in this example. 🤣

Whether this can be fixed remains to be seen, but I guess I will play around with it now, too, as the way the model is crazy is kinda... interesting. Happy Friday! 🤩

[Image: model-crazy]
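A quick way to quantify "it cannot see the cat" is a zero-shot check: embed an image and a few prompts such as "a photo of a cat", then see which prompt has the highest cosine similarity. The ranking helper below is a minimal, hypothetical sketch on plain vectors; with OpenAI's `clip` package, the vectors would come from `model.encode_image(...)` and `model.encode_text(clip.tokenize(prompts))`.

```python
import math


def rank_prompts(image_emb, prompt_embs, prompts):
    """Return (prompt, cosine similarity) pairs, most similar first."""

    def unit(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    img = unit(image_emb)
    scores = [
        (prompt, sum(a * b for a, b in zip(img, unit(emb))))
        for prompt, emb in zip(prompts, prompt_embs)
    ]
    return sorted(scores, key=lambda s: s[1], reverse=True)
```

A healthy CLIP should put the cat prompt clearly on top for an obvious cat photo; the merged, unaligned model described above does not.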

bash-j commented 3 months ago

Happy Friday!

Wow, thanks for figuring that out! What cat are you talking about? I don't see a cat. 🤣

zer0int commented 3 months ago

Update: The bad news is that you can't just fix a ruined CLIP like that. As a stand-alone CLIP, it still can't correctly predict even easy images (such as the cat you can't see either 😛). The numbers from fine-tuning clearly show that, too.

The good news is: I just fine-tuned it ANYWAY, for 5 epochs on CoCo-40k (i.e. under an hour or so), and it restored coherence and brought the TE much closer to the INTENDED concept. It's not as good as a separate fine-tune, imo, but for my dataset ("deepdream neurons") it greatly improved the "ruined CLIP". Worth a shot if you don't want to spend many hours fine-tuning a separately trained CLIP TE!

So, I guess I have to say "thanks for raising the issue!", because this is a useful thing to know, indeed! 😎 👍