lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

cast for bf16 issue #100

Closed xyzhang626 closed 5 months ago

xyzhang626 commented 5 months ago

The current version, when training under bf16, triggers the following mismatched-type error in einsum.

File "/opt/conda/lib/python3.10/site-packages/vector_quantize_pytorch/lookup_free_quantization.py", line 220, in forward
    distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
  File "/opt/conda/lib/python3.10/site-packages/torch/functional.py", line 377, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found BFloat16

Adding an explicit cast solves it.

Actually this should be covered by PyTorch's autocast mechanism, but somehow autocast does not work as expected here. There might be a more elegant fix that triggers autocast directly.
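For reference, a minimal sketch of the explicit-cast workaround (the function name is hypothetical; only the einsum pattern and the variable roles come from the traceback above):

```python
import torch

def lfq_distance(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    # cast the codebook to the input's dtype so einsum sees matching types
    # when the surrounding model runs under bf16
    codebook = codebook.to(dtype=x.dtype)
    # same contraction as in the traceback: '... i d, j d -> ... i j'
    return -2 * torch.einsum('... i d, j d -> ... i j', x, codebook)

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)   # bf16 activations
cb = torch.randn(16, 8)                          # codebook kept in float32
dist = lfq_distance(x, cb)                       # no dtype-mismatch error
```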

lucidrains commented 5 months ago

@xyzhang626 ah, let us just do the quantization in float32 for now

can you let me know if 1.12.10 works?
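A rough sketch of what "do the quantization in float32" can look like (an illustration, not the actual change in 1.12.10): opt out of autocast around the distance computation, run it in fp32, and cast the result back to the input dtype.

```python
import torch

def distance_in_fp32(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Compute the -2 * <x, c> distance term in float32 regardless of autocast."""
    orig_dtype = x.dtype
    # disable autocast so the einsum really runs in fp32
    with torch.autocast(device_type=x.device.type, enabled=False):
        dist = -2 * torch.einsum(
            '... i d, j d -> ... i j', x.float(), codebook.float()
        )
    return dist.to(orig_dtype)
```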

xyzhang626 commented 5 months ago

> @xyzhang626 ah, let us just do the quantization in float32 for now
>
> can you let me know if 1.12.10 works?

@lucidrains I just tried it; 1.12.10 works for bf16.

Just curious, is there any specific reason for using fp32 in the quantization?

lucidrains commented 5 months ago

@xyzhang626 just being cautious, as precision makes a difference in vector quantization. in a residual setup it probably still matters too, depending on how many residual layers there are

in standalone LFQ, not sure! i'd welcome any experiments showing f16 works fine, in which case i'll remove the restriction

xyzhang626 commented 5 months ago

Cool, that makes sense. In my experiments, bf16 works well and fp16 seems fine until the GAN loss is added.

lucidrains commented 5 months ago

@xyzhang626 oh interesting, do you mean the adversarial loss from the VQ-GAN VAE setup? bf16 was fine?

lucidrains commented 5 months ago

bf16, if i'm not mistaken, has lower precision, so one has to be cautious in a residual quantization setup

xyzhang626 commented 5 months ago

Yes, in my VQGAN experiment bf16 is fine and fp16 is not. bf16 has lower precision but a larger range compared to fp16; I think most LLMs today are trained with bf16.
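The precision/range trade-off can be checked directly with torch.finfo:

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # eps reflects relative precision (smaller = finer), max reflects dynamic range
    print(dtype, 'eps =', info.eps, 'max =', info.max)

# fp16 has a max of ~6.5e4 (easy to overflow, e.g. with large GAN losses),
# while bf16 keeps roughly the fp32 range (~3.4e38) at the cost of a coarser eps.
```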

I haven't worked with residual quantization. It's certainly reasonable to keep full precision to be safe. If casting to full precision in the quantization step doesn't add noticeable cost while the rest of the network trains in half precision, that's a good choice.

lucidrains commented 5 months ago

@xyzhang626 thanks for clarifying! let me look into disabling autocast but only for f16
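A possible sketch of that idea (an assumed helper using the CUDA autocast API, not the library's actual code): check the active autocast dtype and only force fp32 when it is float16, letting bf16 pass through.

```python
import torch

def distance_autocast_aware(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    # only fp16 autocast is treated as risky here; bf16 is allowed through
    fp16_autocast = (
        torch.is_autocast_enabled()
        and torch.get_autocast_gpu_dtype() == torch.float16
    )
    if fp16_autocast:
        # run the quantization math in fp32, then cast the result back
        with torch.autocast(device_type='cuda', enabled=False):
            dist = -2 * torch.einsum(
                '... i d, j d -> ... i j', x.float(), codebook.float()
            )
        return dist.to(x.dtype)
    # bf16 autocast or no autocast: just keep the operand dtypes consistent
    return -2 * torch.einsum('... i d, j d -> ... i j', x, codebook.to(x.dtype))
```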