idiap / fast-transformers

PyTorch library for fast transformer implementations

Support for 16-bit Floats #32

Open anti4m opened 4 years ago

anti4m commented 4 years ago

Hi,

I tried running CausalLinear attention with PyTorch Automatic Mixed Precision and got an error: "line 44, in forward CausalDotProduct.dot[device.type]( RuntimeError: expected scalar type Float but found Half"

Is this a bug? Or does your library not offer support for 16-bit precision floats?

Thank you for your time.

apoorv2904 commented 4 years ago

Hi,

The library currently doesn't support mixed precision, or more specifically 16-bit floating-point operations. We are working on half-precision CUDA kernels, and they will be made available in the future. For now, I will tag this as an enhancement.
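Until half-precision kernels exist, one common workaround is to run the fp32-only op outside autocast with upcast inputs and cast the result back. The sketch below assumes nothing about the fast-transformers API: `call_in_fp32` and `fp32_only` are hypothetical names, and `fp32_only` merely imitates a kernel that rejects Half inputs.

```python
import torch


def call_in_fp32(fn, *tensors):
    """Hypothetical helper: run an fp32-only op on possibly-half inputs.

    Autocast is disabled inside the call so it cannot re-cast the inputs,
    the tensors are upcast to float32, and the result is cast back to the
    dtype of the first input.
    """
    out_dtype = tensors[0].dtype
    device_type = tensors[0].device.type
    with torch.autocast(device_type=device_type, enabled=False):
        out = fn(*(t.float() for t in tensors))
    return out.to(out_dtype)


def fp32_only(a, b):
    # Stand-in for a kernel that only accepts float32 (hypothetical).
    if a.dtype != torch.float32 or b.dtype != torch.float32:
        raise RuntimeError("expected scalar type Float but found Half")
    return a @ b.transpose(-1, -2)


x = torch.randn(2, 3).half()
y = call_in_fp32(fp32_only, x, x)  # runs in fp32, returns fp16
```

The two casts add extra memory traffic on every call, forward and backward, so this trades speed for compatibility rather than giving the benefits of a native fp16 kernel.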

Just FYI: you should still be able to use half precision during inference with recurrent attention, which only uses native PyTorch operations.
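To illustrate why the recurrent formulation needs no custom kernel, here is a minimal sketch of one decoding step of causal linear attention written with plain `einsum`. It assumes the elu(x) + 1 feature map and a (N, H, D) per-step layout; it is not the library's recurrent attention code, and `recurrent_linear_attention_step` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def feature_map(x):
    # elu(x) + 1: a positive feature map commonly used for linear attention
    return F.elu(x) + 1


def recurrent_linear_attention_step(q, k, v, state=None, eps=1e-6):
    """One decoding step using only native PyTorch ops (hypothetical helper).

    q, k: (N, H, D); v: (N, H, M) for the current position.
    state: running sums (S, Z) with shapes (N, H, D, M) and (N, H, D).
    All intermediates keep the dtype of the inputs.
    """
    Q, K = feature_map(q), feature_map(k)
    if state is None:
        S = torch.zeros(*K.shape, v.shape[-1], dtype=v.dtype, device=v.device)
        Z = torch.zeros_like(K)
    else:
        S, Z = state
    S = S + torch.einsum("nhd,nhm->nhdm", K, v)  # running sum of k_j v_j^T
    Z = Z + K                                    # running sum of k_j
    num = torch.einsum("nhd,nhdm->nhm", Q, S)
    den = torch.einsum("nhd,nhd->nh", Q, Z).clamp_min(eps)
    return num / den.unsqueeze(-1), (S, Z)
```

Because every op is native, the same code runs on fp16 CUDA tensors. Note, though, that `S` and `Z` only ever grow with the number of steps, so the fp16 range issues discussed later in this thread apply here too.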

Thanks, Apoorv

ncilfone commented 4 years ago

I'm attempting to use the non-causal version of LinearAttention (which doesn't need a CUDA kernel) with PyTorch AMP via the autocast context. It works fine for 100-1000 steps, but eventually a NaN trickles in somewhere...

I've narrowed it down to the fact that fp16 covers a much smaller range of values than fp32 (its largest finite value is 65504) (https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#hpformat). Summing the keys across the sequence-length axis (https://github.com/idiap/fast-transformers/blob/4b345972b73b2dca10ff2b009a9fc9cc06a07133/fast_transformers/attention/linear_attention.py#L75) can produce fairly large values, which can overflow in fp16 (or underflow if the values are too small and the eps value doesn't get cast correctly to the smallest representable fp16 value).
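Both failure modes are easy to reproduce in isolation. The sequence length (4096), the feature value (20.0), and the eps (1e-8) below are illustrative numbers chosen for this sketch, not values taken from the library.

```python
import torch

info = torch.finfo(torch.float16)
print(info.max)   # 65504.0, the largest finite fp16 value

# Overflow: summing positive key features over a long sequence.
# 4096 positions * 20.0 = 81920, which exceeds 65504.
k_features = torch.full((4096,), 20.0, dtype=torch.float16)
print(k_features.sum())          # tensor(inf, dtype=torch.float16)
print(k_features.float().sum())  # tensor(81920.)

# Underflow: a very small eps rounds to exactly zero in fp16, because it is
# below half of the smallest fp16 subnormal (about 5.96e-08).
tiny_eps = torch.tensor(1e-8).to(torch.float16)
print(tiny_eps)                  # tensor(0., dtype=torch.float16)
```

Once a denominator is `inf` or `0`, the following division yields `0`, `inf`, or `nan`, which then propagates through the rest of the network.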

Scaling (see #11; I've implemented something similar in a local fork) helps with the overflow, but I still get a NaN after about 2K steps... not sure whether it's underflow or overflow.

Just tagging this so that you are aware that AMP might cause issues.

[@sidnarayanan as he and I have been discussing]

angeloskath commented 4 years ago

@ncilfone and @sidnarayanan, sorry for the late reply, this sounds like something important that we should investigate.

Do you have a small toy example that replicates the problem?

The simplest solution would be a stabilize boolean flag, so we can spend a few extra cycles stabilizing the attention computation. Just as softmax is shift-invariant, linear attention is scale-invariant, so dividing by a large number (e.g. the max norm of Q or K) would prevent overflow. Making sure that eps is applied properly, by using max(..., self.eps) instead of adding it, would also ensure that we never compute 0/0.
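The two ideas above can be sketched in plain PyTorch. This is a minimal illustration, not the library's LinearAttention implementation: `linear_attention`, its `stabilize` argument, the elu(x) + 1 feature map, and the (N, L, H, D) layout are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def feature_map(x):
    # elu(x) + 1 keeps every feature strictly positive
    return F.elu(x) + 1


def linear_attention(q, k, v, eps=1e-6, stabilize=False):
    """Non-causal linear attention (hypothetical sketch).

    q, k: (N, L, H, D); v: (N, L, H, M). Returns (N, L, H, M).
    """
    Q, K = feature_map(q), feature_map(k)
    if stabilize:
        # Multiplying Q or K by a positive constant per (batch, head) scales
        # numerator and denominator equally, so the output is unchanged.
        # Dividing by the largest row norm bounds every feature by 1, so the
        # sum of K over the sequence is at most L.
        Q = Q / Q.norm(dim=-1, keepdim=True).amax(dim=1, keepdim=True)
        K = K / K.norm(dim=-1, keepdim=True).amax(dim=1, keepdim=True)
    KV = torch.einsum("nshd,nshm->nhmd", K, v)
    # Clamp the normalizer with max(., eps) instead of adding eps, so a zero
    # denominator can never produce 0/0.
    Z = torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)).clamp_min(eps)
    return torch.einsum("nlhd,nhmd->nlhm", Q, KV) / Z.unsqueeze(-1)
```

Choosing a per-(batch, head) scale rather than one global scale keeps the invariance exact, since each batch element and head is normalized independently. The extra norm and max reductions are the "few extra cycles" the flag would cost.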

Cheers, Angelos

gulnazaki commented 3 years ago

Hello @apoorv2904 and @angeloskath, was there any update on fp16 support for CausalDotProduct?

Thanks

calclavia commented 3 years ago

FP16 support for CausalDotProduct would be highly beneficial for my use case as well. Currently, converting back and forth between FP32 and FP16 causes a 10x slowdown during backward passes. Thanks!

angeloskath commented 3 years ago

Sorry for taking so long, guys. Support will come once a generic kernel is implemented for CausalDotProduct; it is not too much work, but it still takes time.
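For reference, the quantity such a kernel has to produce can be written in a few lines of dtype-agnostic PyTorch. This sketch assumes CausalDotProduct computes the unnormalized causal product out_i = sum over j <= i of (q_i · k_j) v_j, and assumes an (N, H, L, D) layout; `causal_dot_product_reference` is a hypothetical name, not part of the library.

```python
import torch


def causal_dot_product_reference(q, k, v):
    """Hypothetical pure-PyTorch reference for a causal dot product.

    Computes out[n, h, i] = sum_{j <= i} (q[n, h, i] . k[n, h, j]) * v[n, h, j]
    for q, k: (N, H, L, D) and v: (N, H, L, M), in the dtype of the inputs.
    The cumulative outer products take O(N*H*L*D*M) memory, so this only
    illustrates the math; a fused kernel avoids materializing them.
    """
    kv = torch.einsum("nhld,nhlm->nhldm", k, v).cumsum(dim=2)
    return torch.einsum("nhld,nhldm->nhlm", q, kv)
```

Since it uses only native ops, it would run on fp16 CUDA tensors, but it is far more memory-hungry than a dedicated kernel, and the cumulative sum in fp16 is subject to the same overflow concerns raised earlier in this thread.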

Thank you for your patience all these months.

Angelos

15805383399 commented 3 years ago

Has there been any progress on the fp16 kernel, @angeloskath?