AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Use attn_mask_type of causal_padding for cudnn_flash_attention #913

Closed bvandermoon closed 1 month ago

bvandermoon commented 1 month ago

Notes

Update attn_mask_type so that it is not ignored. Related to issue 878.

Note that this change currently only impacts workloads using cudnn_flash_te attention.
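For illustration, here is a minimal NumPy sketch of the semantics a causal-plus-padding mask type implies: each query position attends only to key positions that are neither in the future nor padding. This is not the actual MaxText or Transformer Engine implementation; the function name and shapes are assumptions for this sketch.

```python
import numpy as np

def causal_padding_mask(seq_len, valid_lens):
    """Illustrative sketch (not MaxText/TE code): build the boolean mask
    implied by a causal + padding attention mask type.

    Entry [b, q, k] is True where query position q may attend to key
    position k: k <= q (causal) and k < valid_lens[b] (not padding).
    """
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))           # (q, k)
    positions = np.arange(seq_len)
    not_padding = positions[None, :] < np.asarray(valid_lens)[:, None]  # (b, k)
    return causal[None, :, :] & not_padding[:, None, :]                 # (b, q, k)

# One sequence of length 4 where only the first 3 tokens are real.
mask = causal_padding_mask(4, valid_lens=[3])
# Query 2 can see keys 0..2; no query may attend to the padded key 3.
```

A mask type that is silently ignored would instead let every query attend to every key, including padding, which is the kind of bug this change guards against.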

Testing

[Screenshot of test run output, 2024-10-01]

gobbleturk commented 1 month ago

> Notes
>
> Update attn_mask_type so that it is not ignored. Related to issue 878.
>
> Note that this change currently only impacts workloads using cudnn_flash_te attention.
>
> Testing
>
> Trained 10 steps on GPUs with cudnn_flash_te enabled. Saw the same high-level output (step time, loss, etc.) before and after this change.
>
> Note: The loss was suspicious since it dropped from 10.307 to 0 after the third step. But this happened both before and after this change, so it appears to be caused by something else.

I assume the loss is dropping so quickly because you are using synthetic data (dataset_type=synthetic); you can use real data to measure the loss. Ideally we would test correctness with something like our golden logits test.
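The golden-logits idea suggested above can be sketched as a simple numerical comparison: run the model on a fixed input and assert that its logits match stored reference values within a tolerance. This is a generic illustration, not MaxText's actual test; the names and the tolerance value here are assumptions.

```python
import numpy as np

def check_against_golden(logits, golden, atol=1e-2):
    """Sketch of a golden-logits correctness check: the change under test
    should not move the model's outputs beyond numerical noise."""
    np.testing.assert_allclose(logits, golden, atol=atol)

# Toy example: pretend these came from the model before/after the change.
golden = np.array([0.12, -1.30, 2.05])
logits = golden + 1e-4 * np.random.randn(3)  # tiny numerical drift
check_against_golden(logits, golden)  # passes: drift is within tolerance
```

Unlike comparing loss curves on synthetic data, this catches silent behavior changes (such as a mask being ignored) because any change in the attention output shifts the logits on real inputs.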

bvandermoon commented 1 month ago

> Notes
>
> Update attn_mask_type so that it is not ignored. Related to issue 878. Note that this change currently only impacts workloads using cudnn_flash_te attention.
>
> Testing
>
> Trained 10 steps on GPUs with cudnn_flash_te enabled. Saw the same high-level output (step time, loss, etc.) before and after this change. Note: The loss was suspicious since it dropped from 10.307 to 0 after the third step. But this happened both before and after this change, so it appears to be caused by something else.
>
> I assume the loss is dropping so quickly because you are using synthetic data (dataset_type=synthetic); you can use real data to measure the loss. Ideally we would test correctness with something like our golden logits test.

Thanks Matt, that was the issue with the loss. I also ran the golden logits test and it passed. Updated the description with the test.