AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Use attn_mask_type of causal_padding for cudnn_flash_attention #913

Open bvandermoon opened 1 day ago

bvandermoon commented 1 day ago

Notes

Update attn_mask_type so that it is not ignored. Related to issue #878.

Note that this change currently only impacts workloads using cudnn_flash_te attention.

Testing

Trained 10 steps on GPUs with cudnn_flash_te enabled. Saw the same high-level output (step time, loss, etc.) before and after this change.

Note: the loss is suspicious, since it dropped from 10.307 to 0 after the third step. However, this happened both before and after this change, so it appears to be caused by something else.
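For context, here is a minimal sketch (plain JAX, not the MaxText or Transformer Engine implementation) of what honoring the mask type means: a causal-only mask still lets queries attend to padded key positions, while a combined causal + padding mask blocks them. The helper name and mask shapes below are illustrative assumptions.

```python
import jax.numpy as jnp

def combined_causal_padding_mask(seq_len, lengths):
    """Illustrative only: build a [batch, 1, seq, seq] boolean mask (True = attend).

    A causal-only mask allows every query to attend to all earlier positions,
    including padding tokens; combining it with a padding mask derived from the
    true sequence lengths also blocks attention to padded key positions, which
    is the behavior a causal+padding attn_mask_type is meant to provide.
    """
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))       # [seq, seq]
    valid_keys = jnp.arange(seq_len)[None, :] < lengths[:, None]      # [batch, seq]
    padding = valid_keys[:, None, :]                                  # [batch, 1, seq]
    return (causal[None, :, :] & padding)[:, None, :, :]              # [batch, 1, seq, seq]

# Example: batch of 2, seq_len 4, true lengths 4 and 2.
mask = combined_causal_padding_mask(4, jnp.array([4, 2]))
```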

gobbleturk commented 14 hours ago

(quoting bvandermoon's "Notes" and "Testing" comment above)

I assume the loss is dropping so quickly because you are using synthetic data (dataset_type=synthetic); you can use real data to measure the loss. Ideally we would test correctness with something like our golden logits test, along the lines of the sketch below.
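A rough sketch of that kind of check, assuming reference logits have been saved to an .npz file with a "logits" array; the file format, tolerances, and helper name are assumptions, not MaxText's actual golden-logits test.

```python
import numpy as np

def check_against_golden_logits(model_logits, golden_path, rtol=1e-2, atol=1e-2):
    # Load previously recorded reference logits and compare them element-wise
    # against the logits produced with attention=cudnn_flash_te. The storage
    # format here (.npz with a "logits" key) is an assumption for illustration.
    golden = np.load(golden_path)["logits"]
    np.testing.assert_allclose(np.asarray(model_logits), golden, rtol=rtol, atol=atol)
```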