zer0int / CLIP-fine-tune

Fine-tuning code for CLIP models
MIT License

I want to fine-tune a complete text encoder model, but it seems that the model trained by ft-B-train-OpenAI-CLIP-ViT-L-14.py is a visual encoder model. #16

Open vxiaobai opened 1 month ago

vxiaobai commented 1 month ago

First of all, thank you for your work. I have a question: I want to fine-tune a complete text encoder model, but it seems that the model trained by ft-B-train-OpenAI-CLIP-ViT-L-14.py is a visual encoder model. How can I get a pure text-encoder model like ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors, which you provide on your HF page?

zer0int commented 1 month ago

The fine-tuned model is actually a text-vision model, consisting of a text transformer AND a vision transformer. For the "TE only" (text-encoder-only) models on my HuggingFace, I fine-tuned the entire CLIP model (text + vision) and then simply "detached" the vision transformer, i.e. deleted its keys / associated parameters. CLIP's objective is in the name: Contrastive Language-Image Pretraining. The optimization goal is to learn text and image jointly, making the dot product high for matching image-text pairs and low for negative (non-matching) pairs. By definition, it needs both image and text to be a "CLIP".
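The contrastive objective described above can be sketched as the standard symmetric CLIP-style loss. This is a dependency-free illustration under stated assumptions, not code from this repository: it uses a fixed temperature of 0.07 (OpenAI's CLIP instead learns the temperature through its `logit_scale` parameter, initialized to log(1/0.07)), and `clip_style_loss` / `l2_normalize` are hypothetical helper names.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def clip_style_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric contrastive (InfoNCE) loss for a batch of matching pairs.

    image_emb[i] and text_emb[i] describe the same sample; every other
    image/text combination in the batch is treated as a negative.
    Assumption: fixed temperature instead of CLIP's learned logit_scale.
    """
    images = [l2_normalize(v) for v in image_emb]
    texts = [l2_normalize(v) for v in text_emb]
    # logits[i][j]: similarity of image i and text j, scaled by 1/temperature
    logits = [[sum(a * b for a, b in zip(img, txt)) / temperature for txt in texts]
              for img in images]

    def mean_cross_entropy(rows):
        # For row i the correct class is column i (its matching partner).
        total = 0.0
        for i, row in enumerate(rows):
            top = max(row)  # subtract the max for numerical stability
            log_sum_exp = top + math.log(sum(math.exp(x - top) for x in row))
            total += log_sum_exp - row[i]
        return total / len(rows)

    transposed = [list(col) for col in zip(*logits)]  # text -> image direction
    return 0.5 * (mean_cross_entropy(logits) + mean_cross_entropy(transposed))


aligned = clip_style_loss([[1, 0], [0, 1]], [[1, 0], [0, 1]])
swapped = clip_style_loss([[1, 0], [0, 1]], [[0, 1], [1, 0]])
print(f"aligned pairs: {aligned:.4f}, swapped pairs: {swapped:.4f}")
```

In this toy example, matched pairs drive the loss toward zero, while swapping the texts turns every target into a negative and the loss jumps to roughly 1/temperature. Averaging the image-to-text and text-to-image directions is what makes the loss symmetric.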

So, the question is: what are you trying to achieve? Or do you mean that you only want to train the text encoder, with a frozen visual encoder (no parameter updates)? In that case:

The vision transformer is visual.transformer.resblocks[i] (plus visual.proj and so on); the text transformer is transformer.resblocks[i] (no 'visual' prefix). So, to train only the text encoder parameters while keeping the visual encoder frozen (but still using a contrastive text-image loss), you could use something like this:

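def freeze_clip_selectively(model):
    # Substrings of parameter names to freeze. In OpenAI CLIP, every
    # vision-tower parameter name starts with "visual." (e.g. visual.proj).
    frozen_keys = ['visual']
    for name, param in model.named_parameters():
        # Everything else (text transformer, token/positional embeddings,
        # ln_final, text_projection, logit_scale) stays trainable.
        param.requires_grad = not any(key in name for key in frozen_keys)

# in trainloop(), before "for epoch [...]":

freeze_clip_selectively(model)
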
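
To double-check the freeze before training, a small sanity check like the following can help. `ToyCLIP` is a hypothetical stand-in that only mimics CLIP's top-level parameter names (`visual.*`, `transformer.*`, `logit_scale`); the same loop works on a real OpenAI CLIP model. The sketch also hands only the trainable tensors to the optimizer; AdamW and the learning rate are arbitrary choices here, not the repository's settings.

```python
import torch
from torch import nn


class ToyCLIP(nn.Module):
    """Hypothetical stand-in with CLIP-like parameter names (not the real model)."""

    def __init__(self):
        super().__init__()
        self.visual = nn.Linear(4, 2)        # -> visual.weight, visual.bias
        self.transformer = nn.Linear(4, 2)   # -> transformer.weight, transformer.bias
        self.logit_scale = nn.Parameter(torch.ones([]))


def freeze_clip_selectively(model, frozen_keys=("visual",)):
    # Same substring rule as above: freeze any parameter whose name matches.
    for name, param in model.named_parameters():
        param.requires_grad = not any(key in name for key in frozen_keys)


model = ToyCLIP()
freeze_clip_selectively(model)

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:", frozen)

# Pass only the tensors that still require gradients to the optimizer.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-5
)
```

With this substring rule, `logit_scale` (CLIP's learned temperature) remains trainable; passing `frozen_keys=("visual", "logit_scale")` would keep the temperature fixed as well.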
vxiaobai commented 1 month ago

For the "TE only" / text encoder only models on my HuggingFace, I fine-tuned the entire CLIP model (text + vision) and then simply "detached" the vision transformer (i.e. delete the keys / associated parameters).

Can you please give me the code for this? I want to use it with the Flux model. I tested the text-encoder-only model you provided on HF and it works with Flux, and now I want to train the CLIP model as a multilingual model, but I am not familiar with the steps to "separate" the vision transformer. I would appreciate your help, thank you very much. Also, thank you very much for your code; I learned a lot about multimodality from it.

zer0int commented 1 month ago

I just committed Convert-for-HuggingFace-Spaces-etc; the folder contains all the scripts plus documentation on how to use them. Please let me know if that works for you!
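For reference, the core of the "detach" step can be sketched as filtering a state_dict by key prefix. This is an assumption-labeled sketch, not the contents of the Convert-for-HuggingFace-Spaces-etc scripts: `strip_vision_tower` is a hypothetical helper, the placeholder integers stand in for tensors, and the Hugging Face prefixes listed are the ones used by `transformers`' `CLIPModel`.

```python
from collections import OrderedDict

# OpenAI/CLIP checkpoints keep the whole vision tower under "visual.".
OPENAI_VISION_PREFIXES = ("visual.",)
# Hugging Face CLIPModel checkpoints use these prefixes instead.
HF_VISION_PREFIXES = ("vision_model.", "visual_projection.")


def strip_vision_tower(state_dict, prefixes=OPENAI_VISION_PREFIXES):
    """Return a copy of state_dict without any vision-tower entries."""
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.startswith(prefixes)  # str.startswith accepts a tuple
    )


# Placeholder values; a real state_dict maps these names to tensors.
openai_sd = OrderedDict([
    ("positional_embedding", 0),
    ("text_projection", 1),
    ("logit_scale", 2),
    ("transformer.resblocks.0.attn.in_proj_weight", 3),
    ("visual.proj", 4),
    ("visual.transformer.resblocks.0.attn.in_proj_weight", 5),
])
text_only = strip_vision_tower(openai_sd)
print(list(text_only))
# ['positional_embedding', 'text_projection', 'logit_scale', 'transformer.resblocks.0.attn.in_proj_weight']
```

With real weights you would `torch.load` the fine-tuned state_dict with `map_location="cpu"`, filter it, and write the result back with `torch.save` (or `safetensors.torch.save_file` for a .safetensors file). Which keys a downstream consumer such as a Flux pipeline expects (for example, whether it wants `text_projection` or `logit_scale`) may differ, so compare against the published text-encoder file.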

vxiaobai commented 1 month ago


Thank you very much. I think the code you provided is what I want, but I ran into a problem when converting; the error message is below. Have you encountered the same problem? I am training with each of your training scripts separately and then trying the conversion on each resulting model:

state_dict = torch.load(opened_file, map_location="cpu")
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/clip/clip.py", line 129, in load
    model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
  ...
  File "/opt/conda/lib/python3.11/site-packages/torch/jit/_serialization.py", line 165, in load
    cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/finetune_CLIP/CLIP-fine-tune/Convert-for-HuggingFace-Spaces-etc/convert_clip_original_pytorch_to_hf.py", line 156, in <module>
    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/finetune_CLIP/CLIP-fine-tune/Convert-for-HuggingFace-Spaces-etc/convert_clip_original_pytorch_to_hf.py", line 120, in convert_clip_checkpoint
    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
  File "/opt/conda/lib/python3.11/site-packages/clip/clip.py", line 136, in load
    state_dict = torch.load(opened_file, map_location="cpu")
  File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1114, in load
    return _legacy_load(
  File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1338, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
EOFError: Ran out of input

zer0int commented 1 month ago

Can you open /opt/conda/lib/python3.11/site-packages/clip/clip.py and edit line 129, where it says:

model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()

Change that to:

# model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
model = torch.load(opened_file, map_location="cpu").eval()

I can't reproduce your error, but somebody else reported the same one. My guess is that it's related to the venv / conda setup and to clip.py trying to load the file as a TorchScript (torch.jit) archive. I don't use a venv.

However, torch.jit is only there for "interoperability, speed and production environments", so it isn't needed here, and we can simply load with map_location="cpu" in any case.
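A defensive loader along these lines can be sketched as follows; `load_checkpoint` is a hypothetical helper, not a patch to clip.py. One detail worth hedging: the traceback above shows `torch.jit.load` reading the open file through `import_ir_module_from_buffer`, which likely leaves the file position at the end, so a fallback `torch.load` on the same file object may see an empty stream (one plausible source of "EOFError: Ran out of input"). Rewinding with `f.seek(0)` before the fallback is a cheap precaution. The `weights_only=False` argument assumes PyTorch 1.13 or newer and should only be used for checkpoints you trust.

```python
import os
import tempfile

import torch
from torch import nn


def load_checkpoint(path, device="cpu"):
    """Load a TorchScript archive, a pickled nn.Module, or a plain state_dict."""
    with open(path, "rb") as f:
        try:
            # Only succeeds for TorchScript archives (torch.jit.save / ScriptModule.save).
            return torch.jit.load(f, map_location=device).eval()
        except RuntimeError:
            # torch.jit.load may have consumed the stream; rewind before retrying.
            f.seek(0)
            # weights_only=False is required to unpickle a whole nn.Module on
            # PyTorch >= 2.6. Only use it for files you trust.
            obj = torch.load(f, map_location=device, weights_only=False)
    # A pickled module is put into eval mode; a state_dict is returned as-is.
    return obj.eval() if isinstance(obj, nn.Module) else obj


# Demo: save one toy layer in three formats and load each with the same helper.
net = nn.Linear(3, 2)
tmp = tempfile.mkdtemp()
paths = {
    "full": os.path.join(tmp, "full.pt"),
    "state_dict": os.path.join(tmp, "state_dict.pt"),
    "jit": os.path.join(tmp, "model.jit.pt"),
}
torch.save(net, paths["full"])
torch.save(net.state_dict(), paths["state_dict"])
torch.jit.trace(net, torch.randn(1, 3)).save(paths["jit"])

loaded = {kind: load_checkpoint(p) for kind, p in paths.items()}
for kind, obj in loaded.items():
    print(kind, type(obj).__name__)
```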

If that doesn't work, here is my other guess at a fix (since I can't reproduce the problem): use my ft-C-convert-for-SDXL-comfyUI-OpenAI-CLIP.py script, which converts the full model to a state_dict, and then run the conversion on that converted file instead.
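The "full model to state_dict" step can be sketched with a toy network. This is a minimal illustration, not the ft-C-convert-for-SDXL-comfyUI-OpenAI-CLIP.py script: `make_net` and `full_model_to_state_dict` are hypothetical, and for OpenAI CLIP you would rebuild the model from the weights instead (clip.py does this with `clip.model.build_model(state_dict)` when a file is not a JIT archive, which expects the original OpenAI key names).

```python
import os
import tempfile

import torch
from torch import nn


def full_model_to_state_dict(src_path, dst_path):
    """Re-save a pickled nn.Module checkpoint as a weights-only state_dict."""
    # weights_only=False: unpickling a whole module needs it on PyTorch >= 2.6.
    model = torch.load(src_path, map_location="cpu", weights_only=False)
    state_dict = model.state_dict()
    torch.save(state_dict, dst_path)
    return state_dict


def make_net():
    # Hypothetical toy architecture standing in for the real model class.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


tmp = tempfile.mkdtemp()
full_path = os.path.join(tmp, "full_model.pt")
sd_path = os.path.join(tmp, "state_dict.pt")

original = make_net()
torch.save(original, full_path)              # legacy: whole model object
full_model_to_state_dict(full_path, sd_path)

rebuilt = make_net()                         # fresh architecture ...
rebuilt.load_state_dict(torch.load(sd_path, map_location="cpu"))  # ... plus weights
x = torch.randn(5, 4)
print(torch.allclose(original(x), rebuilt(x)))
```

Saving only the state_dict avoids pickling the model class itself, so the file can be loaded even where the training script's classes are not importable; the architecture is rebuilt separately.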

vxiaobai commented 1 month ago

I found a possible cause of this error on other forums and tried the fix suggested there. It worked, although I'm not sure it was the deciding factor. If you save the model with torch.save(model, model_path), load it with model = torch.load(opened_file, map_location="cpu").eval(). If you need to load it with model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval(), then save the model as a traced TorchScript module: script_model = torch.jit.trace(model, (images, texts)), followed by script_model.save("model.jit.pt"). Hope this helps with your library.
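The two save/load pairings described above can be sketched with a small two-input model. `TwoTower` is a hypothetical stand-in for CLIP's `(images, texts)` forward signature and the tensor shapes are arbitrary; the point is only which loader matches which saver.

```python
import os
import tempfile

import torch
from torch import nn


class TwoTower(nn.Module):
    """Hypothetical toy stand-in for CLIP: forward(images, texts) -> similarity logits."""

    def __init__(self):
        super().__init__()
        self.image_proj = nn.Linear(6, 4)
        self.text_proj = nn.Linear(5, 4)

    def forward(self, images, texts):
        img = nn.functional.normalize(self.image_proj(images), dim=-1)
        txt = nn.functional.normalize(self.text_proj(texts), dim=-1)
        return img @ txt.t()


model = TwoTower().eval()
images, texts = torch.randn(3, 6), torch.randn(3, 5)
tmp = tempfile.mkdtemp()

# 1) Full model object: pair torch.save(model, ...) with torch.load(...).
torch.save(model, os.path.join(tmp, "model.pt"))
reloaded = torch.load(os.path.join(tmp, "model.pt"), map_location="cpu",
                      weights_only=False).eval()

# 2) TorchScript via tracing: pair torch.jit.trace(...).save(...) with torch.jit.load(...).
traced = torch.jit.trace(model, (images, texts))
traced.save(os.path.join(tmp, "model.jit.pt"))
scripted = torch.jit.load(os.path.join(tmp, "model.jit.pt"), map_location="cpu").eval()

with torch.no_grad():
    reference = model(images, texts)
    same_pickle = torch.allclose(reference, reloaded(images, texts))
    same_jit = torch.allclose(reference, scripted(images, texts))
print(same_pickle, same_jit)
```

Tracing records the operations executed for the example inputs, so data-dependent Python control flow inside `forward` is not captured; a pickled model object, on the other hand, needs its class definition importable at load time.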

zer0int commented 1 month ago

Thank you for the suggestion, and glad you got it to work! I'll try it and consider implementing it as a boolean switch in my next update: True if you want to script the model, otherwise save with a normal torch.save. 👍

zer0int commented 2 weeks ago

I updated the code with a new model saver. You can now choose to either save as GmP (legacy behavior) or directly convert back to .weight (original OpenAI/CLIP format; no extra conversion script needed anymore!). Plus, you can save the model as 1. a full model object (legacy behavior), 2. a state_dict, or 3. a torch.jit.trace() model, or any combination of those.
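Converting GmP back to a plain `.weight` can be sketched as folding a magnitude and a direction into one matrix. This rests on an assumption about the parametrization: that each GmP layer stores a per-output-row magnitude `r` and direction `theta`, with effective weight row i = r[i] · theta[i] / ||theta[i]||. Check the repository's GmP code for the exact definition and key names; `gmp_to_weight` and `weight_to_gmp` are hypothetical helpers written in plain Python.

```python
import math

EPS = 1e-12  # same default clamp torch.nn.functional.normalize uses


def gmp_to_weight(r, theta):
    """Fold a geometric parametrization (magnitude r, direction theta) into a weight.

    Assumption: row i of the effective weight is r[i] * theta[i] / ||theta[i]||.
    """
    weight = []
    for radius, row in zip(r, theta):
        norm = max(math.sqrt(sum(x * x for x in row)), EPS)
        weight.append([radius * x / norm for x in row])
    return weight


def weight_to_gmp(weight):
    """Inverse direction: split each row into its length and its direction."""
    r = [math.sqrt(sum(x * x for x in row)) for row in weight]
    theta = [list(row) for row in weight]  # any positive rescaling works too
    return r, theta


w = [[3.0, 4.0], [0.0, -2.0]]
r, theta = weight_to_gmp(w)
print(r)                       # [5.0, 2.0]
print(gmp_to_weight(r, theta))  # [[3.0, 4.0], [0.0, -2.0]]
```

In PyTorch the same folding is `r * torch.nn.functional.normalize(theta, p=2, dim=1)` with `r` shaped `(out_features, 1)`. Because the direction is normalized, rescaling `theta` by any positive factor leaves the resulting weight unchanged.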

Hope it's useful to you! 👍