pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[Bug] Unusual CPU overhead of SDPA call on H100 on torch nightly #1652

Open Jackmin801 opened 1 week ago

Jackmin801 commented 1 week ago

Issue identified: the cuDNN SDPA backend JIT-recompiles whenever the context length changes. Training that does not use packing therefore keeps recompiling, which produces the observed 500 ms overhead.
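To make the mechanism concrete, here is a toy model in pure Python of a kernel cache keyed on the attention input shape. It is not cuDNN's or PyTorch's actual code. It only assumes what the thread concludes: a new sequence length forces a fresh JIT compile, and a repeated shape is a cache hit. The class name `ToyCompileCache`, the `(batch, heads, seq_len, head_dim)` key layout, and the 32-layer count are illustrative assumptions. The same logic would apply separately to the backward kernel.

```python
from typing import Dict, List, Tuple

# (batch, num_heads, seq_len, head_dim): assumed cache key for this toy model
Shape = Tuple[int, int, int, int]


class ToyCompileCache:
    """Hypothetical stand-in for a shape-keyed JIT kernel cache."""

    def __init__(self) -> None:
        self._kernels: Dict[Shape, str] = {}
        self.compiles = 0

    def get_kernel(self, shape: Shape) -> str:
        if shape not in self._kernels:
            # Cache miss: in the real backend this is where the slow JIT compile happens
            self.compiles += 1
            self._kernels[shape] = f"kernel{shape}"
        return self._kernels[shape]


def run_epoch(seq_lens: List[int], num_layers: int = 32) -> int:
    """Simulate one forward pass per batch and return how many compiles occurred."""
    cache = ToyCompileCache()
    for seq_len in seq_lens:
        for _layer in range(num_layers):
            # Every layer in a batch sees the same shape, so only the first layer can miss
            cache.get_kernel((2, 32, seq_len, 128))
    return cache.compiles


# Unpacked data: each batch is padded to its own longest sample
print(run_epoch([3071, 5120, 4410, 3071]))  # 3
# Packed (or fixed-length padded) data: every batch has the same length
print(run_epoch([8192, 8192, 8192, 8192]))  # 1
```

In this model the overhead lands on the first layer of every batch whose length has not been seen before, which matches the per-batch pattern reported below.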


There seems to be something unusual about the performance characteristics of the attention implementation when running Llama models on H100: the first SDPA forward call and the first SDPA backward call carry an unusually long overhead. This isn't observed on other cards (tested on a 4090 and a 3090), which use the flash attention implementation.


A warning appears in the output about mismatched strides between output and grad_output in the cuDNN SDPA backward:

/root/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Steps to reproduce

1. Clone the repo at this branch

git clone --branch feat-activation-offloading-streams-dist git@github.com:Jackmin801/torchtune
cd torchtune

2. Do a dev install with the torch nightly

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip install torchao
pip install matplotlib pandas numpy
pip install -e ".[dev]"

3. Log in to Hugging Face and download the model

huggingface-cli login --token <HF_TOKEN>
tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/*.pth"

4. Run the single-device recipe

# tune run --nnodes 1 --nproc-per-node 2 lora_finetune_distributed --config recipes/configs/llama3_1/8B_lora.yaml
tune run lora_finetune_single_device --config recipes/configs/llama3_1/8B_lora.yaml \
  optimizer_in_bwd=False \
  enable_activation_offloading=False \
  enable_activation_checkpointing=True \
  compile=False \
  model.lora_attn_modules="[q_proj,v_proj]" \
  model.apply_lora_to_mlp=False \
  model.apply_lora_to_output=False \
  model.lora_rank=8 \
  model.lora_alpha=16 \
  dataset.source=Yukang/LongAlpaca-12k \
  dataset.packed=False \
  dataset.split=train[:10%] \
  dataset.train_on_input=True \
  tokenizer.max_seq_len=8192 \
  metric_logger=torchtune.training.metric_logging.StdoutLogger \
  metric_logger.project=recipe_profiling \
  log_every_n_steps=1 \
  log_peak_memory_stats=True \
  gradient_accumulation_steps=1 \
  max_steps_per_epoch=4 \
  epochs=1 \
  batch_size=2 \
  metric_logger.name=llama3__qlora__seqlen_8192__act_ckpt_True__act_off_True__bs2 \
  profiler.enabled=True \
  profiler.profile_memory=True \
  profiler.with_stack=True \
  profiler.wait_steps=0 \
  profiler.warmup_steps=2 \
  profiler.active_steps=2 \
  profiler.num_cycles=1

5. Observe the trace

Weird, right?

felipemello1 commented 1 week ago

Thanks for reporting it! Quick question: do you know whether this happens only in the very first batch, or in the first layer of every batch?

If it happens only in the very first batch, it may just be the time it takes to compile.

Edit: I just realized that the config has compile=False. I wonder if it could be related to pos_embedding initialization.

Jackmin801 commented 1 week ago

It seems to occur every batch. Hmm, I don't think it's about pos_embeds; otherwise, wouldn't it happen with flash attention too?

Jackmin801 commented 1 week ago

I'm wondering if it's a regression in cuDNN, so I'm building PyTorch with older cuDNN versions to see if anything changes. Changing the cuDNN version to 9.2.0 doesn't seem to help. The good news is that the dispatcher only seems to select cuDNN SDPA in the nightly; the current stable release (2.4.1) selects flash SDPA, which works fine.

Jackmin801 commented 1 week ago

It seems to come from the CUDNN_ATTENTION implementation of SDPA. With the viable/strict build of PyTorch, you can toggle the bug by forcing the FLASH or CUDNN implementation:

    from torch.nn.attention import sdpa_kernel, SDPBackend

    # `recipe` is the torchtune recipe object built by the recipe script
    # Forcing flash attention: the per-batch overhead disappears
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        recipe.train()

    # Forcing cuDNN attention: the per-batch overhead reappears
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        recipe.train()

@eqy @drisspg @Skylion007

eqy commented 1 week ago

This overhead might be expected, as the cuDNN SDPA backend does JIT compilation. Could you share a minimal example, e.g., one exercising the shape(s) that show the issue? We can work with cuDNN to try to get the overhead reduced.

If it is occurring every batch (as in "It seems to occur every batch") with identical shapes, strides, etc., then this would be a bug, as the compiled kernel should be cached. Can you confirm this?

janeyx99 commented 1 week ago

Ah, it only occurs on the first call in the forward pass and the first call in the backward pass, which is consistent with the compiled kernel being cached.

Jackmin801 commented 6 days ago

Ah, that might be it. The shapes differ between the first and second batches, so padding to the max length should fix this, right?
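A minimal sketch of the padding idea, assuming each sample is a dict of token-ID lists and a fixed `max_seq_len` (8192 in the repro config). The function name `pad_to_fixed_length`, `pad_id=0`, and the `ignore_idx=-100` label fill are illustrative assumptions, not torchtune's collate API. Every batch comes out with the same sequence length, so a shape-keyed kernel would be compiled once and reused. The cost is that attention now runs over padding tokens. Because padding is appended on the right and attention is causal, real tokens never attend to the pad positions.

```python
from typing import Dict, List


def pad_to_fixed_length(
    batch: List[Dict[str, List[int]]],
    max_seq_len: int,
    pad_id: int = 0,  # assumed padding token id
    ignore_idx: int = -100,  # assumed label value ignored by the loss
) -> Dict[str, List[List[int]]]:
    """Right-pad (or truncate) every sample to exactly max_seq_len tokens."""
    tokens_out: List[List[int]] = []
    labels_out: List[List[int]] = []
    for sample in batch:
        tokens = sample["tokens"][:max_seq_len]
        labels = sample["labels"][:max_seq_len]
        tokens_out.append(tokens + [pad_id] * (max_seq_len - len(tokens)))
        labels_out.append(labels + [ignore_idx] * (max_seq_len - len(labels)))
    return {"tokens": tokens_out, "labels": labels_out}


batch = [
    {"tokens": [5, 6, 7], "labels": [5, 6, 7]},
    {"tokens": [8, 9], "labels": [8, 9]},
]
padded = pad_to_fixed_length(batch, max_seq_len=4)
print(padded["tokens"])  # [[5, 6, 7, 0], [8, 9, 0, 0]]
print(padded["labels"])  # [[5, 6, 7, -100], [8, 9, -100, -100]]
```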

drisspg commented 6 days ago

@eqy does cuDNN JIT-compile for every new sequence length? That seems non-ideal.

eqy commented 6 days ago

@drisspg at present it would trigger recompilation, but I'll reach out to cuDNN to see if we can get reuse here.
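Until kernels can be reused across sequence lengths, one hypothetical middle ground is to round each batch's length up to a bucket boundary. This sits between padding each batch to its own longest sample, which compiles once per new length, and padding everything to 8192, which wastes compute. The approach was not proposed in the thread. `bucket_length`, `distinct_shapes`, and the 1024-token bucket size are illustrative assumptions. With 1024-token buckets and `max_seq_len=8192`, at most 8 distinct lengths can occur, and therefore at most 8 compiles per pass direction under a shape-keyed cache.

```python
from typing import List


def bucket_length(longest: int, bucket: int = 1024, max_seq_len: int = 8192) -> int:
    """Round a batch's longest sequence up to the next multiple of `bucket`, capped at max_seq_len."""
    if longest <= 0:
        raise ValueError("longest must be positive")
    # Integer ceiling division avoids floating-point rounding issues
    rounded = ((longest + bucket - 1) // bucket) * bucket
    return min(rounded, max_seq_len)


def distinct_shapes(batch_maxima: List[int], bucket: int = 1024) -> int:
    """Count distinct padded lengths, i.e. the compiles a shape-keyed cache would see."""
    return len({bucket_length(n, bucket) for n in batch_maxima})


lengths = [3071, 5120, 4410, 3500, 7999, 3033]
print(len(set(lengths)))         # 6
print(distinct_shapes(lengths))  # 4
```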