zer0int / CLIP-fine-tune

Fine-tuning code for CLIP models
MIT License

loading weights with `clip.load()` throws `EOFError: Ran out of input` #11

Closed: SkyLull closed this issue 2 weeks ago

SkyLull commented 2 weeks ago

Description

I'm trying to load the CLIP weights in my code, but loading fails. I bypass the SHA256 check in `clip.load()` by passing a file path instead of a model name, in which case it simply loads that file. The same approach loads the original CLIP models successfully, for example `clip.load("/home/user/.cache/clip/ViT-L-14.pt")`. I suspect this has something to do with a different torch (or Python) version. Can you share your local environment details?

Here is what I have done:

Environment

Steps

python

Python 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import clip
>>> clip.load("ViT-L-14-TEXT-detail-improved-hiT-GmP-state_dict.pt")


/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/clip/clip.py:136: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(opened_file, map_location="cpu")
Traceback (most recent call last):
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/clip/clip.py", line 129, in load
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/jit/_serialization.py", line 165, in load
cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/clip/clip.py", line 136, in load
state_dict = torch.load(opened_file, map_location="cpu")
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/serialization.py", line 1114, in load
return _legacy_load(
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/serialization.py", line 1338, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
EOFError: Ran out of input

>>> clip.load("ViT-L-14-TEXT-detail-improved-hiT-GmP-pickle-OpenAI.pt")

Traceback (most recent call last):
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/clip/clip.py", line 129, in load
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/jit/_serialization.py", line 165, in load
cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/clip/clip.py", line 136, in load
state_dict = torch.load(opened_file, map_location="cpu")
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/serialization.py", line 1114, in load
return _legacy_load(
File "/home/jack/miniconda3/envs/oi/lib/python3.10/site-packages/torch/serialization.py", line 1338, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
EOFError: Ran out of input

zer0int commented 2 weeks ago

Judging by the filename, you have downloaded the state_dict version (and not the full model object).

You can either:

  1. Download the full pickle instead and use that. If you then get another jit error, search your code for anything that passes `jit=` and remove it; my model is a plain `torch.save`, not a jit archive.

  2. Load the state_dict you have already downloaded into the ViT-L/14 model:

    import clip
    import torch

    # Build the stock ViT-L/14 architecture, then overwrite its weights.
    model, preprocess = clip.load("ViT-L/14")
    state_dict = torch.load("path/to/ViT-L-14-TEXT-detail-improved-hiT-GmP-state_dict.pt", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # or .train(), depending on what you're trying to do
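
For anyone unsure which kind of file they have, here is a rough sketch of a container check. It is not part of this repo, and `describe_checkpoint` is a hypothetical helper name. It assumes the zip-based layout that recent PyTorch versions write, where a TorchScript (jit) archive carries a `constants.pkl` entry (the file the `PytorchStreamReader` error above complains about) and a plain `torch.save` file carries `data.pkl` without it. It only needs the standard library, so it runs without loading the model.

```python
import zipfile


def describe_checkpoint(path):
    """Roughly classify a PyTorch checkpoint file by its container layout."""
    if not zipfile.is_zipfile(path):
        # Empty, truncated, or legacy (pre-zip) torch.save files end up here.
        return "not a zip archive"
    with zipfile.ZipFile(path) as archive:
        # Entries look like "<archive name>/constants.pkl"; compare basenames only.
        names = {name.rsplit("/", 1)[-1] for name in archive.namelist()}
    if "constants.pkl" in names:
        return "TorchScript archive"
    if "data.pkl" in names:
        return "torch.save pickle"
    return "unknown zip archive"
```

Calling `describe_checkpoint("path/to/ViT-L-14-TEXT-detail-improved-hiT-GmP-state_dict.pt")` should then tell you whether the jit path can work at all. Note that it cannot distinguish a state_dict from a full pickled model, since both are written by `torch.save`; for that, the filename is the better hint.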

Hope that helps!

SkyLull commented 2 weeks ago

THANK YOU VERY MUCH! It worked! I did not realize it worked this way. Thank you for your work, your time, and your guidance!