openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
MIT License

RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' #415

Closed · Viditagarwal7479 closed this 8 months ago

Viditagarwal7479 commented 8 months ago

Hello, I first converted the model to fp16, like this:

import clip
import torch

model, _ = clip.load("ViT-B/32")
clip.model.convert_weights(model)  # cast the applicable weights from fp32 to fp16
torch.save(model, "modelfp16.pth")

Since convert_weights is intended to convert the model from fp32 to fp16, what could be the reason it isn't working?
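
A quick way to verify the conversion itself is to reload the checkpoint and inspect the parameter dtypes. A minimal sketch, assuming the modelfp16.pth file saved above:

import torch

model = torch.load("modelfp16.pth")
# convert_weights casts Linear/Conv/attention weights to fp16, while
# LayerNorm parameters stay fp32, so seeing both dtypes is expected.
print({p.dtype for p in model.parameters()})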

When I used this saved model, it failed with the following traceback:

Traceback

    text_features = model.encode_text(tokenized_text)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 348, in encode_text
    x = self.transformer(x)
        ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 203, in forward
    return self.resblocks(x)
           ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 190, in forward
    x = x + self.attention(self.ln_1(x))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 187, in attention
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/functional.py", line 5440, in multi_head_attention_forward
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: c10::Half and  query.dtype: c10::BFloat16 instead.
(mlx) vidit@Vidits-MacBook-Air image_retrival % python main.py
Enter Query:yellow tshirt
yellow tshirt
tensor(14.2857)
Traceback (most recent call last):
  File "/Users/vidit/image_retrival/main.py", line 60, in <module>
    print(check_score(model, search_prompt, cursor))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vidit/image_retrival/main.py", line 31, in check_score
    text_features = model.encode_text(tokenized_text)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 348, in encode_text
    x = self.transformer(x)
        ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 203, in forward
    return self.resblocks(x)
           ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 190, in forward
    x = x + self.attention(self.ln_1(x))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/clip/model.py", line 187, in attention
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/mlx/lib/python3.11/site-packages/torch/nn/functional.py", line 4824, in _in_projection_packed
    proj = linear(q, w, b)
           ^^^^^^^^^^^^^^^
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
Viditagarwal7479 commented 8 months ago

That layer is implemented to work with fp16 only on the GPU; the CPU kernels (such as addmm) have no Half implementation. When running on the CPU, use fp32.
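
A minimal sketch of that workaround, assuming the standard clip package API: convert to fp16 only when a CUDA device is available, and keep (or cast back to) fp32 on the CPU.

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

if device == "cpu":
    # CPU kernels such as addmm have no Half implementation, so ensure
    # every parameter is fp32 (this undoes a prior convert_weights call).
    model.float()
else:
    # fp16 is supported on GPU and roughly halves memory use.
    clip.model.convert_weights(model)

Note that clip.load already returns an fp32 model when device="cpu", so the model.float() call only matters for a checkpoint that was previously converted with convert_weights.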