Issue closed by Vinicius-ufsc.
It looks like the error shows up when importing torch, so there might be something wrong with the installation. I'd probably try a clean install. See also https://github.com/pytorch/pytorch/issues/51080#issuecomment-787133939
Installing torch with pip did solve the torch error (see https://github.com/pytorch/pytorch/issues/51080#issuecomment-787019521):
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
But now I'm getting an UnpicklingError.
Maybe I'm not passing the expected files to the --load argument.
Could you please explain what the .pt files should be?
I'm using the original ViT-L-14-336px.pt file from OpenAI's CLIP.
The file is downloaded automatically when the model is loaded with clip:
import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-L/14@336px', device=device, jit=False)
The fine-tuned model (ft_01_6ep_lr2e6.pt) was trained on a custom dataset using ViT-L/14@336px (E2E) with a script I wrote, and saved with:
import os
from torch import save

# Save a training checkpoint (epoch, model weights, optimizer state) as a dict
def save_model(model, optimizer, epoch, name):
    # create the model_checkpoint folder if it does not exist
    os.makedirs('model_checkpoint', exist_ok=True)
    current_working_directory = os.getcwd()
    model_name = f'{name}.pt'
    save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join('model_checkpoint', model_name))
    print(f"Model saved at {os.path.join(current_working_directory, 'model_checkpoint', model_name)}")
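The file written by save_model is a dict holding the epoch, the model's state dict and the optimizer's state dict; it is not a pickled model object, so it cannot serve as a classifier until its weights are loaded into a model built the same way as during training. A minimal sketch of pulling the weights back out, using a hypothetical helper `unwrap_checkpoint` (not part of wise-ft or PyTorch), assuming the checkpoint has already been read with torch.load:

```python
from collections.abc import Mapping


def unwrap_checkpoint(checkpoint):
    """Return model weights from a save_model-style dict or a bare state dict.

    Hypothetical helper for illustration only.
    """
    if isinstance(checkpoint, Mapping) and "model_state_dict" in checkpoint:
        # Checkpoint written by save_model above: weights live under this key.
        return checkpoint["model_state_dict"]
    # Otherwise assume it is already a plain state dict.
    return checkpoint


# Usage, assuming torch is installed and `model` is constructed exactly as in training:
#   checkpoint = torch.load("model_checkpoint/ft_01_6ep_lr2e6.pt", map_location="cpu")
#   model.load_state_dict(unwrap_checkpoint(checkpoint))
```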
Then I run the command to interpolate the weights:
python src/wise_ft.py \
    --load=/home/user/.cache/clip/ViT-L-14-336px.pt,/home/user/model_checkpoint/ft_01_6ep_lr2e6.pt \
    --results-db=results.jsonl \
    --save=models/wiseft \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
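For context, each value passed to --alpha selects one point on the line between the two sets of weights: per parameter, θ = (1 - α)·θ_zeroshot + α·θ_finetuned, so α = 0 gives the zero-shot model and α = 1 the fine-tuned one. A minimal sketch on plain dicts of floats standing in for tensors; `interpolate_weights` is a hypothetical name, not the repository's function. It only makes sense when both files hold the same architecture with the same parameter names, as here (both ViT-L/14@336px):

```python
def interpolate_weights(theta_0, theta_1, alpha):
    """Per-parameter interpolation: (1 - alpha) * theta_0 + alpha * theta_1."""
    if set(theta_0) != set(theta_1):
        # Both checkpoints must come from the same architecture.
        raise ValueError("state dicts have different parameter names")
    return {name: (1 - alpha) * theta_0[name] + alpha * theta_1[name] for name in theta_0}


# Toy "state dicts" with scalar parameters standing in for tensors.
zeroshot = {"w": 0.0, "b": 2.0}
finetuned = {"w": 1.0, "b": 4.0}

# Same grid as the --alpha values in the command above.
for alpha in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
    merged = interpolate_weights(zeroshot, finetuned, alpha)
    print(alpha, merged)
```

The same expression works unchanged on torch tensors, since scalar multiplication and addition are elementwise.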
and get this error:
WARNING:root:The WILDS package is out of date. Your version is 1.1.0, while the latest version is 2.0.0.
Loading image classifier from /home/user/.cache/clip/ViT-L-14-336px.pt
Traceback (most recent call last):
File "src/wise_ft.py", line 100, in <module>
wise_ft(args)
File "src/wise_ft.py", line 65, in wise_ft
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
File "/home/user/wise-ft/src/models/modeling.py", line 85, in load
return utils.torch_load(filename)
File "/home/user/wise-ft/src/models/utils.py", line 50, in torch_load
classifier = pickle.load(f)
_pickle.UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.
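A likely explanation for this particular message: the traceback shows utils.torch_load calling plain pickle.load, which only reads raw pickle streams. Files written by torch.save (zip format by default since PyTorch 1.6) and TorchScript archives such as the CLIP .pt files are zip archives that begin with the bytes PK, and b"P" happens to be pickle's "load persistent id" opcode, which is exactly what the unpickler complains about. A small sketch, using a hypothetical helper `sniff_checkpoint_format`, that inspects the first bytes of a file before trying to unpickle it:

```python
ZIP_MAGIC = b"PK\x03\x04"  # signature at the start of a zip archive
PICKLE_PROTO = b"\x80"     # first byte of a pickle stream, protocol 2 or newer


def sniff_checkpoint_format(path):
    """Heuristically guess how a .pt file was written, from its first bytes.

    Hypothetical helper; it only distinguishes the two cases relevant here.
    """
    with open(path, "rb") as f:
        head = f.read(4)
    if head.startswith(ZIP_MAGIC):
        # Zip container (torch.save >= 1.6 default, or a TorchScript archive):
        # read it with torch.load / torch.jit.load, not pickle.load.
        return "zip"
    if head[:1] == PICKLE_PROTO:
        # Raw pickle stream: pickle.load can read it.
        return "pickle"
    return "unknown"
```

Under these assumptions, the CLIP file in ~/.cache/clip would be expected to report "zip", while a file that utils.torch_load can read with pickle.load should report "pickle".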
Could you provide an example of model files (.pt) that are tested and known to work, so I can check whether the problem is my models?
If you're trying to load with the ImageClassifier.load function, you should save with the ImageClassifier.save function.
Thank you for the support!
Summary of this issue:
If you run into the error:
Traceback (most recent call last):
File "wise_ft.py", line 5, in <module>
import torch
File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 189, in <module>
_load_global_deps()
File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 142, in _load_global_deps
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
File "/home/user/miniconda3/envs/wiseft/lib/python3.6/ctypes/__init__.py", line 348, in __init__
self._handle = _dlopen(self._name, mode)
OSError: /home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/lib/../../../../libcublas.so.11: undefined symbol: free_gemm_select, version libcublasLt.so.11
consider installing torch with the following line:
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
You cannot load the model directly: first you need to create an ImageClassifier, then save it with the ImageClassifier.save function.
Hi, I'm getting an OSError when trying to run the interpolation.
I want to interpolate ViT-L-14-336px.pt with my fine-tuned.pt model but can't solve this issue. Any ideas?
I ran the code below to create the env (no errors or warnings):
And the code to interpolate:
error: