Closed by willisma 1 week ago
Hmm, it seems that the arg force_fp32_for_softmax in MultiHeadDotProductAttention wasn't actually threaded through. I will make a PR to fix this. In the meantime, maybe you can initialize your MultiHeadDotProductAttention layer with attention_fn=functools.partial(nn.dot_product_attention, force_fp32_for_softmax=True) to use the arg for now?
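For example, a minimal sketch of that workaround (assuming a Flax version in which nn.dot_product_attention accepts force_fp32_for_softmax; num_heads, dtype, and shapes below are placeholders):

```python
import functools

import jax
import jax.numpy as jnp
import flax.linen as nn

# Workaround: force the softmax to run in fp32 by threading the flag through
# attention_fn, since the layer's own force_fp32_for_softmax attribute is not
# forwarded to the attention function yet.
attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    dtype=jnp.bfloat16,  # keep activations/compute in bf16
    attention_fn=functools.partial(
        nn.dot_product_attention,
        force_fp32_for_softmax=True,  # softmax itself computed in fp32
    ),
)

x = jnp.ones((2, 16, 64), dtype=jnp.bfloat16)  # (batch, seq_len, features)
params = attn.init(jax.random.PRNGKey(0), x)   # self-attention over x
y = attn.apply(params, x)
```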
Hi,
As I was trying to implement mixed precision training under Flax for my project, I noticed that the force_fp32_for_softmax flag defined in attention.MultiHeadDotProductAttention does not get passed into dot_product_attention (the default attention function). I think this means we lose control over the softmax operator and can run into stability issues under bf16 or fp16 precision, so I wonder if there is an alternative? A minimal sketch of the setup I mean is below. Thanks!