openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image

clip.model.build_model does not work if device is cpu #209

Open · Toekan opened this issue 2 years ago

Toekan commented 2 years ago

Hi,

Thanks for providing this really convenient package to use the CLIP model!

I've come across a problem with build_model when trying to reconstruct the model from a state_dict on my local computer without GPU.

Code to reproduce

First I download one of the built-in models and save the state_dict:

import torch
import clip

model, preprocess = clip.load("ViT-B/32", jit=False, device="cpu")
torch.save(model.state_dict(), 'clip_off_the_shelve.pt')

Then I load the model using your function and try to use it to infer a text embedding:

model = clip.model.build_model(torch.load('clip_off_the_shelve.pt'))
text_tokens = clip.tokenize(["door"])

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

This unfortunately results in:

~/Envs/test_env/lib/python3.8/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel, dtype)
   1678         dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
   1679     if dtype is None:
-> 1680         ret = input.softmax(dim)
   1681     else:
   1682         ret = input.softmax(dim, dtype=dtype)

RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Half'

Potential fix (?)

From reading around, it seems the culprit is convert_weights in build_model, which converts the weights to fp16 regardless of the device being used. PyTorch doesn't implement many fp16 ("Half") kernels on CPU, softmax among them, which produces the error above. Would it be possible to make build_model conditional on the device?
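
For what it's worth, casting the model back to fp32 right after build_model seems to sidestep the fp16 kernels entirely; a minimal sketch, reusing the state_dict saved above:

import torch
import clip

# build_model converts the weights to fp16 internally; .float() casts
# everything back to fp32 so the CPU softmax kernel can be used
model = clip.model.build_model(torch.load('clip_off_the_shelve.pt')).float()

text_tokens = clip.tokenize(["door"])
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)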

Thanks!

jongwook commented 2 years ago

The conditional code can be found here in clip.load():

https://github.com/openai/CLIP/blob/3482bb6ed319f70542094d1ed224c0db0b88c3a5/clip/clip.py#L138-L141
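
For reference, those lines read roughly:

model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
    model.float()
return model, preprocess

i.e. clip.load() casts the model back to fp32 whenever the target device is CPU,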

so clip.load("clip_off_the_shelve.pt") should work; please let me know if it doesn't.

w1redch4d commented 2 years ago

Facing the same error, and @jongwook, clip.load("clip_off_the_shelve.pt") doesn't work either

jongwook commented 2 years ago

By clip_off_the_shelve.pt I meant the models downloaded under ~/.cache/clip. Let me know what the stacktrace looks like if you see an error loading those models with clip.load().
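
For example, something like this (assuming the default cache location and the ViT-B/32 checkpoint, which is cached as ViT-B-32.pt):

import os
import clip

# load the previously downloaded checkpoint directly from the cache
model_path = os.path.expanduser("~/.cache/clip/ViT-B-32.pt")
model, preprocess = clip.load(model_path, device="cpu")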

w1redch4d commented 2 years ago

The stacktrace:

Traceback (most recent call last):
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 109, in <module>
    Predictor().predict()
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 25, in predict
    images = VQ_Diffusion_model.generate_sample_with_condition(
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 90, in generate_sample_with_condition
    model_out = self.model.generate_content(
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\models\dalle.py", line 215, in generate_content
    trans_out = self.transformer.sample(condition_token=condition['condition_token'],
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\transformers\diffusion_transformer.py", line 578, in sample
    cond_emb = self.condition_emb(input['condition_token']) # B x Ld x D   #256*1024
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\embeddings\clip_text_embedding.py", line 72, in forward
    text_feature = self.encode_text(index)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\embeddings\clip_text_embedding.py", line 52, in encode_text
    x = self.transformer(x)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 198, in forward
    return self.resblocks(x)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
    input = module(input)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 185, in forward
    x = x + self.attention(self.ln_1(x))
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 182, in attention
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\activation.py", line 1038, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 5358, in multi_head_attention_forward
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 5037, in _scaled_dot_product_attention
    attn = softmax(attn, dim=-1)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 1818, in softmax
    ret = input.softmax(dim)
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Half'

Anyhow, the error goes away after implementing a simple check in the build_model function:

def build_model(state_dict: dict, device: str):
    ...
    # fp16 kernels (e.g. softmax) are not implemented on CPU,
    # so only convert the weights when running on a GPU
    if str(device) != "cpu":
        convert_weights(model)

    model.load_state_dict(state_dict)
    return model.eval()
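
With a device-aware build_model along those lines, the repro from the original report would presumably become the following (the device parameter here is the hypothetical addition, not the current API):

import torch
import clip

# pass the target device so build_model can skip the fp16 conversion on CPU
model = clip.model.build_model(torch.load('clip_off_the_shelve.pt'), device="cpu")

text_tokens = clip.tokenize(["door"])
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()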