lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

LFQ half precision problem #116 #145

Closed denadai2 closed 4 months ago

denadai2 commented 4 months ago

I believe there is a problem similar to #116.

  File "/tmp/ray/session_2024-06-30_09-41-50_254745_1/runtime_resources/pip/ed0d17a5f9a959a3d03116db0bba20a6c15cac27/virtualenv/lib/python3.10/site-packages/vector_quantize_pytorch/lookup_free_quantization.py", line 273, in forward
    x = self.project_in(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

thxxx

PS: I'd like to run the quantizer in float32 inside a bfloat16 module under FSDP, but I don't know how.
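
For context, something like the following is what I have in mind (just a sketch; KeepFloat32 is a made-up wrapper, not part of the library, and under FSDP mixed precision the quantizer would also need to be excluded from the bfloat16 param_dtype policy, since FSDP casts the parameters themselves):

    import torch
    from torch import nn

    class KeepFloat32(nn.Module):
        # Hypothetical wrapper, not part of vector-quantize-pytorch:
        # runs the wrapped quantizer in float32 even when called
        # inside a bfloat16 autocast region.
        def __init__(self, quantizer: nn.Module):
            super().__init__()
            self.quantizer = quantizer.float()  # keep params in float32

        def forward(self, x: torch.Tensor):
            orig_dtype = x.dtype
            # locally turn autocast off so matmuls run in float32
            with torch.autocast(device_type=x.device.type, enabled=False):
                out = self.quantizer(x.float())
            # hand results back in the caller's dtype
            if torch.is_tensor(out):
                return out.to(orig_dtype)
            return tuple(o.to(orig_dtype) if torch.is_tensor(o) else o for o in out)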

lucidrains commented 4 months ago

@denadai2 hey Marco! thanks for testing out this new quantizer

could you see if the latest version fixes your issue?

lucidrains commented 4 months ago

@denadai2 are you seeing a lot of success with this technique?

denadai2 commented 4 months ago

> @denadai2 are you seeing a lot of success with this technique?

Zero success so far, ahah, but I'll keep you updated! What about you?

lucidrains commented 4 months ago

@denadai2 i'm seeing better results with FSQ

but i haven't tried this spherical flavor on any real data just yet

denadai2 commented 4 months ago

I see! I'll try that one as well. I just have to better understand the paper :P

  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/tmp/ray/session_2024-06-30_09-41-50_254745_1/runtime_resources/pip/b1dd9d9db9545febf3d5ce2059c5b9fc44317bfb/virtualenv/lib/python3.10/site-packages/vector_quantize_pytorch/lookup_free_quantization.py", line 321, in forward
    distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
  File "/opt/conda/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found Half

It kind of works better than before, but we still get the error above.

denadai2 commented 4 months ago

Question: this https://github.com/lucidrains/vector-quantize-pytorch/blob/1bce1c3b80296f64612f808942460c3a955dec3f/vector_quantize_pytorch/lookup_free_quantization.py#L244 disables autocast, while the recent #145 enables it by default. Should we enable it only if it is already enabled?
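
Something along these lines is what I mean (just a sketch with a hypothetical helper, not the library's actual code):

    import torch

    # Hypothetical helper, not library code: respect the caller's
    # autocast state instead of unconditionally forcing one.
    def codebook_distances(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
        if torch.is_autocast_enabled():
            # caller runs AMP: drop to float32 locally for the
            # numerically sensitive distance computation
            with torch.autocast(device_type=x.device.type, enabled=False):
                return -2 * torch.einsum('... i d, j d -> ... i j',
                                         x.float(), codebook.float())
        # no AMP active: compute in whatever dtype the caller uses
        return -2 * torch.einsum('... i d, j d -> ... i j', x, codebook)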

lucidrains commented 4 months ago

@denadai2 oops, try one more time?

lucidrains commented 4 months ago

@denadai2 working?

lucidrains commented 4 months ago

@denadai2 also try the latest version with this setting set to False

denadai2 commented 4 months ago

Thanks @lucidrains!! I'll test it tomorrow or the day after. I got held up by bugs earlier in the pipeline, eheh.

denadai2 commented 4 months ago

woooorkiiiinnggg! thx @lucidrains. I'll keep you updated with the exps :))

lucidrains commented 4 months ago

@denadai2 happy training Marco!

hummat commented 2 months ago

> question: this
>
> https://github.com/lucidrains/vector-quantize-pytorch/blob/1bce1c3b80296f64612f808942460c3a955dec3f/vector_quantize_pytorch/lookup_free_quantization.py#L244
>
> disables autocast while the recent #145 enables it by default. Should we enable it only if it is already enabled?

Sorry for reopening this, but while LFQ now works with AMP, FSQ still doesn't for me because of this line. I still get RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16. Removing the line resolves the issue.
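
A minimal repro sketch of what I'm seeing, following the README-style FSQ usage (the CUDA device and the extra encoder layer are just assumptions to get bfloat16 activations into the quantizer):

    import torch
    from torch import nn
    from vector_quantize_pytorch import FSQ

    encoder = nn.Linear(32, 16).cuda()
    quantizer = FSQ(levels=[8, 5, 5, 5], dim=16).cuda()

    x = torch.randn(1, 1024, 32, device='cuda')

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        h = encoder(x)  # bfloat16 under autocast
        # on affected versions FSQ disables autocast internally, so its
        # float32 projection meets the bfloat16 activations and raises
        # the dtype-mismatch RuntimeError above
        xhat, indices = quantizer(h)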

lucidrains commented 2 months ago

@hummat i think you are on an older version of the library

hummat commented 2 months ago

@lucidrains thanks for the quick reply. I'm on 1.17.1 from pip, but even the current master branch has it, right?

https://github.com/lucidrains/vector-quantize-pytorch/blob/c302cf3282161e81ebcf77a627b6d0e8bf34b069/vector_quantize_pytorch/finite_scalar_quantization.py#L162

lucidrains commented 2 months ago

@hummat oh oops, that should have been removed a long time ago

could you try 1.17.3?

hummat commented 2 months ago

@lucidrains aha, yes, in 1.17.3 it's fine :)