axolotl-ai-cloud / axolotl

Zero loss and nan grad_norm when Flash Attention is enabled #1706

Open fgdfgfthgr-fox opened 4 months ago

fgdfgfthgr-fox commented 4 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

I expect similar loss and grad_norm when training a model with the same settings, regardless of whether flash attention is enabled.

Current behaviour

Currently, during training (right from the start), I see messages like `{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 6.545084971874738e-06, 'epoch': 0.4}` for a few steps, before the following error appears and training stops:

```
  File "/home/huada524/ondemand/data/sys/myjobs/projects/default/1/huada524-prune-env/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 110, in unpad_input
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
RuntimeError: CUDA error: an illegal memory access was encountered
```

However, if flash attention is disabled with `flash_attention: false`, the network trains normally: `{'loss': 3.0972, 'grad_norm': 0.76171875, 'learning_rate': 3.4549150281252635e-06, 'epoch': 0.6}`
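For reference, here is a minimal standalone check, independent of axolotl (it assumes flash-attn is importable and a CUDA GPU is visible), that exercises the flash-attn kernel directly. If the flash-attn build itself were broken, I would expect this to fail as well:

```python
# Minimal flash-attn sanity check, independent of axolotl.
# Assumes flash-attn is installed and a CUDA device is visible.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), bf16 as used in training
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()  # surface any asynchronous CUDA errors here
print(out.shape, "nan:", out.isnan().any().item())
```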

Steps to reproduce

1. I installed axolotl on a remote cluster with 3x L40 GPUs (scheduled with Slurm), using the following script:

```bash
module load python
module load cuda

echo "Setting up python venv..."
python -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
pip install -U wheel
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 -I
pip install ninja
export TORCH_CUDA_ARCH_LIST="8.6;8.9"
export CUDA_VISIBLE_DEVICES=2
export LD_LIBRARY_PATH=/home/huada524/ondemand/data/sys/myjobs/projects/default/1/venv/lib64/python3.10/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
pip install -v -U "git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"

cd axolotl
git pull
pip install packaging
pip install -e '.[flash-attn,deepspeed]'
cd ..
```

I manually removed xformers from axolotl/requirements.txt so the install would not override the build I had just compiled.

I also had to apply the patch from https://github.com/microsoft/DeepSpeed/issues/5603 to get axolotl to launch.

2. I started the training using the script below:

```bash
module load python
module load cuda

source venv/bin/activate

export CUDA_VISIBLE_DEVICES=2

export WANDB_API_KEY=xxxxxxx
export LD_LIBRARY_PATH=/home/huada524/ondemand/data/sys/myjobs/projects/default/1/venv/lib64/python3.10/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH

accelerate launch -m axolotl.cli.train config_llama3_40B_dora.yaml
```

Note: the model I am training the LoRA on is Meta's Llama-3-70B with some of its layers removed.
The GPUs are running CUDA driver 12.5, while the loaded CUDA module is 12.3.
```
****************************************
**** Axolotl Dependency Versions *****
  accelerate: 0.30.1
        peft: 0.11.1
transformers: 4.41.1
         trl: 0.8.7.dev0
       torch: 2.4.0.dev20240610+cu124
bitsandbytes: 0.43.1
****************************************
```
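To double-check the driver/toolkit mismatch mentioned above, a quick report along these lines (nothing axolotl-specific; it only assumes PyTorch can see the GPU) shows what the PyTorch wheel was built against versus what the device reports:

```python
# Quick environment cross-check: PyTorch build vs. visible device.
import torch

print("torch:", torch.__version__)                       # e.g. 2.4.0.dev20240610+cu124
print("built with CUDA:", torch.version.cuda)             # toolkit the wheel was compiled against
print("device:", torch.cuda.get_device_name(0))           # expect NVIDIA L40
print("compute capability:", torch.cuda.get_device_capability(0))  # L40 is (8, 9)
print("bf16 supported:", torch.cuda.is_bf16_supported())
```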

### Config yaml

```yaml
base_model: /home/huada524/ondemand/data/sys/myjobs/projects/default/1/PruneMe/slice_with_mergekit/merged
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    name: default
    type: completion
val_set_size: 0.05
max_steps: 10
output_dir: ./outputs/out

adapter: qlora
lora_r: 8
lora_alpha: 4
lora_dropout: 0.0
lora_target_linear: true
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
#  - lm_head
peft_use_dora: false
lora_model_dir:

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

wandb_mode: online
wandb_project: Creating Llama-3-40B
wandb_entity:
wandb_watch:
wandb_name: Experimental_Runs

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 1e-5
#loraplus_lr_embedding: 1e-6

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 0
evals_per_epoch: 1
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
#  - full_shard
#  - auto_wrap
fsdp_config:
#  fsdp_limit_all_gathers: true
#  fsdp_sync_module_states: true
#  fsdp_offload_params: true
#  fsdp_use_orig_params: false
#  fsdp_cpu_ram_efficient_loading: true
#  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
#  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
#  fsdp_state_dict_type: FULL_STATE_DICT
#  fsdp_sharding_strategy: FULL_SHARD
special_tokens:
  pad_token: <|end_of_text|>
```

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

5783839c6e29bb148041338772040c85aaae4646

Acknowledgements

bofei5675 commented 1 month ago

I get the same error.

saucam commented 2 weeks ago

I think I get the same error

```
[rank22]:   File "/opt/conda/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 212, in pad_input
[rank22]:     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
[rank22]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank22]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
[rank22]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank22]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank22]:   File "/opt/conda/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 51, in forward
[rank22]:     output[indices] = values
[rank22]:     ~~~~~~^^^^^^^^^
[rank22]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank22]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank22]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[rank22]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f39d2711897 in /opt/conda/lib/python3.11/site-packages/torch/lib/libc10.so)
```
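As the message suggests, rerunning with CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so the traceback points at the call that actually faults rather than a later API call. A minimal sketch of forcing that from Python (the variable has to be set before CUDA is initialized, e.g. at the very top of the entry script or in the launch environment):

```python
# Make CUDA kernel launches synchronous so the illegal-memory-access error
# is raised at the faulting call. Must be set before CUDA is initialized.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported only after the environment variable is set
assert torch.cuda.is_available()
# ... reproduce the failing training step here; the resulting traceback
# should now point at the kernel that triggers the illegal memory access.
```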