google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Force `fp32` in `attention.MultiHeadDotProductAttention` for softmax operator #4008

Closed: willisma closed this issue 1 week ago

willisma commented 2 weeks ago

Hi,

As I was trying to implement mixed precision training in Flax for my project, I noticed that the `force_fp32_for_softmax` flag defined in `attention.MultiHeadDotProductAttention` is never passed through to `dot_product_attention` (the default attention function).


I think this could lead to a loss of control over the softmax operator and cause stability issues under bf16 or fp16 precision. Is there an alternative in the meantime? Thanks!
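For context, the general idea behind a flag like this is to run the softmax normalization in float32 even when the surrounding attention computation is in bf16/fp16, then cast the weights back down. The sketch below shows that idea in plain JAX. `softmax_maybe_fp32` is a hypothetical helper written for illustration, not Flax's internal code, and how much the upcast helps in practice depends on sequence length and the range of the logits.

```python
import jax
import jax.numpy as jnp

# Low-precision dtypes for which the (hypothetical) helper upcasts to fp32.
_LOW_PRECISION = (jnp.dtype(jnp.bfloat16), jnp.dtype(jnp.float16))


def softmax_maybe_fp32(logits, force_fp32=True):
    """Hypothetical helper: softmax over the last axis.

    If force_fp32 is set and the logits are bf16/fp16, the exp/normalize
    step runs in float32 and the result is cast back to the input dtype.
    """
    if force_fp32 and logits.dtype in _LOW_PRECISION:
        probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
        return probs.astype(logits.dtype)
    # Otherwise the softmax is computed in the input dtype.
    return jax.nn.softmax(logits, axis=-1)


if __name__ == "__main__":
    # Compare both paths against a float32 reference on random bf16 logits.
    key = jax.random.PRNGKey(0)
    logits = (4.0 * jax.random.normal(key, (4, 1024))).astype(jnp.bfloat16)
    ref = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    for force in (False, True):
        p = softmax_maybe_fp32(logits, force_fp32=force).astype(jnp.float32)
        print(f"force_fp32={force}: max abs error {float(jnp.max(jnp.abs(p - ref))):.2e}")
```

Either way the returned weights are in the input dtype; the flag only changes where the normalization itself is computed.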

IvyZX commented 1 week ago

Hmm, it seems that the `force_fp32_for_softmax` arg in `MultiHeadDotProductAttention` wasn't actually threaded through. I will make a PR to fix this. Meanwhile, maybe you can initialize your `MultiHeadDotProductAttention` layer with `attention_fn=functools.partial(nn.dot_product_attention, force_fp32_for_softmax=True)` to use the arg for now?