AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

Mask is being ignored when cudnn_flash_attention is used #878

Open · finbarrtimbers opened this issue 2 weeks ago

finbarrtimbers commented 2 weeks ago

From the TransformerEngine docs:

> mask in `__call__` is ignored for `'no_mask'` and `'causal'`.

I think `attn_mask_type` should be set to `causal_padding`; otherwise the mask being passed on line 370 is ignored.
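
To illustrate what is at stake: with a causal-only mask type, the padding information in the passed-in mask has no effect, so queries can still attend to padded key positions. Below is a minimal JAX sketch of that difference; it is a toy illustration, not MaxText or TransformerEngine code, and the names in it (`seq_len`, `valid_len`, the mask construction) are assumptions made for the example.

```python
# Toy example, not MaxText or TransformerEngine code: sequence length,
# padding layout, and mask construction are illustrative assumptions.
import jax.numpy as jnp

seq_len = 6
valid_len = 4  # pretend positions 4 and 5 are padding tokens

# Causal mask: True where attention is allowed (query i may attend to key j <= i).
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

# Padding mask: True only for keys that are real (non-padding) tokens.
key_is_valid = jnp.arange(seq_len) < valid_len
padding = jnp.broadcast_to(key_is_valid[None, :], (seq_len, seq_len))

# What a causal-only mask effectively uses vs. a causal + padding mask.
causal_only = causal
causal_and_padding = causal & padding

# The last query row differs: with the causal-only mask, query 5 may attend
# to the padded keys 4 and 5; with the combined mask, it cannot.
print(causal_only[-1])         # [ True  True  True  True  True  True]
print(causal_and_padding[-1])  # [ True  True  True  True False False]
```

In other words, for the `mask` argument to matter at all, the attention call has to be configured with a padding-aware mask type, which is what the `causal_padding` suggestion above is about.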

bvandermoon commented 1 day ago

Thanks Finbarr. Created a PR to address this.