Jackmin801 opened 1 week ago
thanks for reporting it! Quick question: Do you know if that happens in the very first batch vs in the first layer of every single batch?
Because if it happens only in the very first batch, then it may just be the time it takes to compile.
edit: just realized that the config has compile=False. I wonder if it could have to do with pos_embedding initialization
It seems to occur every batch. Hmm, I don't think it's about pos_embeds; otherwise it would happen for flash attention too, right?
I'm wondering if it's a regression from cuDNN, so I'm building PyTorch with older cuDNN versions to see if anything changes. Changing the cuDNN version to 9.2.0 doesn't seem to help. The good news is that the dispatcher only seems to select cuDNN SDPA in the nightly; the current stable release (2.4.1) selects flash SDPA, which works fine.
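For reference, recent PyTorch builds expose global flags showing which SDPA backends the dispatcher is currently allowed to pick. A quick sketch (flag names are looked up defensively, since availability varies by PyTorch version):

```python
# Print which SDPA backends the dispatcher may currently use.
# getattr guards against older builds that lack some of these flags.
import torch

for name in ("flash_sdp_enabled", "mem_efficient_sdp_enabled",
             "math_sdp_enabled", "cudnn_sdp_enabled"):
    fn = getattr(torch.backends.cuda, name, None)
    print(name, "=", fn() if fn is not None else "not in this build")
```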
This seems to come from the CUDNN_ATTENTION implementation of SDPA. With a viable/strict build of PyTorch, you can toggle the bug by forcing the FLASH or CUDNN implementation:
Works fine (flash):

```python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    recipe.train()
```

Shows the overhead (cuDNN):

```python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    recipe.train()
```
@eqy @drisspg @Skylion007
This overhead might be expected, as the cuDNN SDPA backend does JIT compilation. Could you share a minimal example, e.g. one exercising the shape(s) that show the issue? We can work with cuDNN to try to get the overhead reduced.
If it is occurring every batch (as in "It seems to occur every batch") with identical shapes/strides/etc., then this would be a bug, as the compiled kernel should be cached. Can you confirm this?
Ah, it does only occur on the first call in the fwd and the first call in the bwd, so that caching would make sense.
Ah, that might be it. The shapes are different between the first batch and the second batch, so padding to max length should fix this, right?
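A minimal sketch of that padding idea (`MAX_SEQ_LEN` and `pad_to_max` are hypothetical names, not from the repo): right-pad every batch to a fixed length so the sequence dimension SDPA sees never changes between batches.

```python
# Hypothetical sketch: pad each batch of token ids to a fixed length so
# the q/k/v shapes passed to SDPA are identical every step.
import torch
import torch.nn.functional as F

MAX_SEQ_LEN = 2048  # assumed training max length

def pad_to_max(tokens: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """Right-pad a [batch, seq_len] token tensor to [batch, MAX_SEQ_LEN]."""
    pad = MAX_SEQ_LEN - tokens.size(1)
    return F.pad(tokens, (0, pad), value=pad_id)

batch = torch.randint(1, 100, (4, 1537))
print(pad_to_max(batch).shape)  # torch.Size([4, 2048])
```

Whether this is worth it depends on how much compute the padded positions waste relative to the ~500 ms recompile; packing avoids both.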
@eqy does cuDNN JIT-compile for every new sequence length? That seems non-ideal.
@drisspg at present it would trigger recompilation, but I'll reach out to the cuDNN team to see if we can get reuse here.
Issue identified: cuDNN SDPA JIT-recompiles whenever the context length changes. As a result, training runs that do not use packing keep recompiling every batch, producing the observed ~500 ms overhead.
There seems to be something unusual about the performance characteristics of the attention implementation when running Llama models on an H100: there is an unusually long overhead in the first SDPA forward and backward. This isn't observed on other cards (tested: 4090 and 3090), which use the flash attention implementation.
A warning appears in the output about mismatched strides in the output and grad output from the cuDNN SDPA backward.
Steps to reproduce
1. Clone the repo at this branch
2. Dev-install with torch nightly
3. Log in to Hugging Face and download the model
4. Run the single-device recipe
5. Observe the trace
Weird, right?