anti4m opened this issue 4 years ago
Hi,
The library currently doesn't support mixed precision, or more specifically 16-bit floating-point operations. We are working on half-precision CUDA kernels, and they will be made available in the future. For now, I will tag this as an enhancement.
Just FYI: you should still be able to use half precision during inference with Recurrent Attention, which only uses native PyTorch operations.
Thanks, Apoorv
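To illustrate why recurrent attention needs no custom kernel, here is a pure-Python sketch (not the library's implementation) that computes causal linear attention in two ways. The first uses explicit causal weights. The second is a recurrence over a running state that needs only ordinary elementwise updates and inner products. Q and K are assumed to have already gone through a positive feature map such as elu(x) + 1. The max(..., eps) clamp is an illustrative choice.

```python
import math

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

def causal_parallel(Q, K, V, eps=1e-6):
    """Causal linear attention computed row by row from explicit weights.

    Row i attends to positions j <= i with weight q_i . k_j, normalized by
    the sum of those weights (clamped at eps).
    """
    M = len(V[0])
    out = []
    for i, q in enumerate(Q):
        weights = [dot(q, K[j]) for j in range(i + 1)]
        denom = max(sum(weights), eps)
        out.append([sum(w * V[j][m] for j, w in enumerate(weights)) / denom
                    for m in range(M)])
    return out

def causal_recurrent(Q, K, V, eps=1e-6):
    """The same quantity, computed one step at a time from a running state.

    S accumulates the outer products k_j v_j^T and z accumulates k_j, so each
    step needs only elementwise updates and inner products.
    """
    D, M = len(K[0]), len(V[0])
    S = [[0.0] * M for _ in range(D)]
    z = [0.0] * D
    out = []
    for q, k, v in zip(Q, K, V):
        for d in range(D):
            z[d] += k[d]
            for m in range(M):
                S[d][m] += k[d] * v[m]
        denom = max(dot(q, z), eps)
        out.append([sum(q[d] * S[d][m] for d in range(D)) / denom
                    for m in range(M)])
    return out

# Toy inputs (hypothetical values, already feature-mapped and positive).
Q = [[1.0, 0.5], [0.2, 2.0], [1.5, 1.5], [0.3, 0.7]]
K = [[0.5, 1.0], [2.0, 0.1], [1.0, 1.0], [0.4, 0.9]]
V = [[1.0, -1.0], [0.0, 2.0], [3.0, 0.5], [-2.0, 1.0]]

parallel = causal_parallel(Q, K, V)
recurrent = causal_recurrent(Q, K, V)
```

Both functions produce the same rows up to rounding. The recurrent form gives up the parallelism of a fused kernel in exchange for step-by-step updates, which suits autoregressive inference, where tokens arrive one at a time anyway.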
I'm attempting to use the non-causal version of LinearAttention (which doesn't need a CUDA kernel) with PyTorch AMP via the autocast context. It works fine for 100-1000 steps, but eventually a NaN trickles in somewhere...
I've narrowed it down to the fact that fp16 covers a much smaller range of values than fp32 (https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#hpformat). Summing the keys across the sequence-length axis (https://github.com/idiap/fast-transformers/blob/4b345972b73b2dca10ff2b009a9fc9cc06a07133/fast_transformers/attention/linear_attention.py#L75) can produce fairly large values, which can overflow in fp16 (or underflow if they are too small and the eps value doesn't get cast correctly to the minimum representable value in fp16).
Scaling (see #11; I've implemented something similar in a local fork) helps with the overflow, but I still get a NaN after roughly 2K steps... not sure whether it's underflow or overflow.
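For anyone who wants a toy reproduction of this range problem, here is a stdlib-only sketch (not the library's code). It emulates fp16 by rounding values through Python's struct "e" (IEEE 754 half-precision) format. The sequence length of 8192, the key feature value of 20.0 and the eps of 1e-8 are hypothetical values chosen for illustration.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value

def to_fp16(x: float) -> float:
    """Round x to the nearest half-precision value.

    struct raises OverflowError for values outside the fp16 range; map that
    to +/-inf, which is what fp16 tensor arithmetic produces instead.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def fp16_sum(values):
    """Sum values, rounding the running total to fp16 after every addition."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v))
    return acc

# Hypothetical key features after a positive feature map: one feature
# equal to 20.0 at every position of a length-8192 sequence.
seq_len = 8192
key_feature = [20.0] * seq_len

print(sum(key_feature))       # double-precision reference: 163840.0
print(fp16_sum(key_feature))  # fp16 result: inf

# A small eps silently becomes 0 in fp16 (the smallest subnormal is ~5.96e-8).
eps = 1e-8
print(to_fp16(eps))           # 0.0
```

The exact total, 163840, exceeds the fp16 maximum of 65504, so the result is inf regardless of how the additions are ordered. The helper rounds after every addition only to keep the emulation simple.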
Just tagging so that you are aware that AMP might cause issues.
[@sidnarayanan as he and I have been discussing]
@ncilfone and @sidnarayanan, sorry for the late reply; this sounds like something important that we should investigate.
Do you have a small toy example that replicates the problem?
The simplest solution would be a stabilize boolean flag, so we can spend a few extra cycles stabilizing the attention computation. Just as softmax is shift invariant, linear attention is scale invariant, so dividing by a large number, e.g. the max norm of Q or K, would prevent overflow. Making sure that eps is properly added by using max(...., self.eps) instead of + would also ensure that we never compute 0/0.
Cheers, Angelos
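Here is a minimal pure-Python sketch of both suggestions, assuming Q and K have already passed through a positive feature map. The stabilize parameter is a hypothetical stand-in for the proposed flag, not an existing library option. Dividing by the largest row norm is one reading of "the max norm of Q or K".

```python
import math

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

def linear_attention(Q, K, V, eps=1e-6, stabilize=False):
    """Non-causal linear attention for a single head on plain Python lists.

    Q and K are N x D lists assumed to have already passed through a positive
    feature map; V is N x M. Returns an N x M list.
    """
    if stabilize:
        # Hypothetical `stabilize` option: divide Q and K by their largest
        # row norm (falling back to 1.0 if every row is zero). Each scale
        # appears in both numerator and denominator, so it cancels.
        q_scale = max(math.sqrt(dot(q, q)) for q in Q) or 1.0
        k_scale = max(math.sqrt(dot(k, k)) for k in K) or 1.0
        Q = [[x / q_scale for x in q] for q in Q]
        K = [[x / k_scale for x in k] for k in K]
    D, M = len(K[0]), len(V[0])
    # KV[d][m] = sum_n K[n][d] * V[n][m] and k_sum[d] = sum_n K[n][d]
    KV = [[sum(k[d] * v[m] for k, v in zip(K, V)) for m in range(M)]
          for d in range(D)]
    k_sum = [sum(k[d] for k in K) for d in range(D)]
    out = []
    for q in Q:
        # Clamp the denominator at eps with max(...) rather than adding eps;
        # the clamp never shifts a denominator that is already above eps.
        denom = max(dot(q, k_sum), eps)
        out.append([sum(q[d] * KV[d][m] for d in range(D)) / denom
                    for m in range(M)])
    return out

# Toy inputs; the keys are scaled up by 1000 to mimic large feature values.
Q = [[1.0, 2.0], [0.5, 0.1], [3.0, 1.0]]
K = [[2.0, 1.0], [1.0, 1.0], [0.2, 4.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
K_big = [[1000.0 * x for x in k] for k in K]

reference = linear_attention(Q, K, V)
plain = linear_attention(Q, K_big, V)
stable = linear_attention(Q, K_big, V, stabilize=True)
for row_r, row_p, row_s in zip(reference, plain, stable):
    # All three agree up to floating-point rounding.
    assert all(math.isclose(r, p, rel_tol=1e-9) and math.isclose(r, s, rel_tol=1e-9)
               for r, p, s in zip(row_r, row_p, row_s))
```

The cancellation is exact only while the scaled denominator stays above eps. Once the clamp kicks in, the stabilized and unstabilized outputs can differ.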
Hello @apoorv2904 and @angeloskath, was there any update on fp16 support for CausalDotProduct?
Thanks
FP16 support for CausalDotProduct would be highly beneficial for my use case as well. Currently, converting back and forth between FP32 and FP16 causes a 10x slowdown during backward passes. Thanks!
Sorry for taking so long, guys. Support will come once a generic kernel is implemented for CausalDotProduct; it is not too much work, but still.
Thank you for your patience all these months.
Angelos
Has there been any progress on the fp16 kernel? @angeloskath
Hi,
I tried running causal linear attention with PyTorch Automatic Mixed Precision and got the following error:
line 44, in forward
    CausalDotProduct.dot[device.type](
RuntimeError: expected scalar type Float but found Half
Is this a bug, or does your library not support 16-bit floats?
Thank you for your time.