pytorch / torchtitan

A native PyTorch Library for large model training

Modify FLOPs in MFU calculation for causal mask when using FlashAttention. #341

Closed · Yuxin-CV closed this 1 month ago

Yuxin-CV commented 1 month ago

Hi, I suggest we modify the FLOPs calculation used for MFU to match the FlashAttention benchmark script.

Specifically, the current calculation for the causal mask can report more than 100% MFU at seq_len = 16k (189 * 2 / 312 ≈ 1.21), which is impossible. With a causal mask, FlashAttention skips the masked-out half of the attention matrix, so the attention FLOPs should be divided by 2.
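
For reference, a minimal sketch of the proposed adjustment, modeled on the FLOPs helper in the FlashAttention benchmark script (the function name `attention_flops` is hypothetical; the 2.5x backward multiplier follows that script's convention, and 312 TFLOP/s assumes an A100's BF16 peak):

```python
def attention_flops(batch: int, seqlen: int, headdim: int, nheads: int,
                    causal: bool, mode: str = "fwd") -> float:
    """Attention FLOPs, counted in the FlashAttention benchmark convention."""
    assert mode in ("fwd", "bwd", "fwd_bwd")
    # Full (bidirectional) attention: two matmuls (QK^T and PV), each costing
    # 2 * seqlen^2 * nheads * headdim FLOPs per batch element.
    f = 4 * batch * seqlen**2 * nheads * headdim
    # A causal mask lets FlashAttention skip the upper-triangular blocks,
    # so only about half of those FLOPs are actually performed.
    if causal:
        f //= 2
    # Backward is conventionally counted as 2.5x forward.
    return {"fwd": f, "bwd": 2.5 * f, "fwd_bwd": 3.5 * f}[mode]

# Without the halving, the measured causal throughput at seq_len = 16k
# (~189 TFLOP/s) gets credited with twice the work actually done:
# 189 * 2 / 312 ≈ 1.21, i.e. an impossible 121% MFU against A100 BF16 peak.
```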

[Figure: flash2_a100_fwd_bwd_benchmark — FlashAttention-2 forward+backward throughput (TFLOP/s) on A100 across sequence lengths]
awgu commented 1 month ago

There was some past discussion on this (https://github.com/pytorch/torchtitan/pull/280).