axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0
7.48k stars, 808 forks

TinyLlama pretrain fails, but SFT works -- CUDA error: an illegal memory access was encountered #1753

Closed: chromecast56 closed this issue 1 month ago

chromecast56 commented 1 month ago

Please check that this issue hasn't been reported before.

Expected Behavior

Training should run without errors.

Current behaviour

[rank3]: Traceback (most recent call last):
[rank3]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank3]:   File "<frozen runpy>", line 88, in _run_code
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 73, in <module>
[rank3]:     fire.Fire(do_cli)
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
[rank3]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank3]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
[rank3]:     component, remaining_args = _CallAndUpdateTrace(
[rank3]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank3]:     component = fn(*varargs, **kwargs)
[rank3]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 39, in do_cli
[rank3]:     return do_train(parsed_cfg, parsed_cli_args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 68, in do_train
[rank3]:     return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/train.py", line 170, in train
[rank3]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
[rank3]:     return inner_training_loop(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
[rank3]:     tr_loss_step = self.training_step(model, inputs)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 3307, in training_step
[rank3]:     loss = self.compute_loss(model, inputs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/core/trainer_builder.py", line 539, in compute_loss
[rank3]:     return super().compute_loss(model, inputs, return_outputs=return_outputs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 3338, in compute_loss
[rank3]:     outputs = model(**inputs)
[rank3]:               ^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1593, in forward
[rank3]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank3]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1411, in _run_ddp_forward
[rank3]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank3]:     return model_forward(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank3]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank3]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
[rank3]:     return func(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
[rank3]:     outputs = self.model(
[rank3]:               ^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 809, in llama_model_forward
[rank3]:     layer_outputs = torch.utils.checkpoint.checkpoint(
[rank3]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
[rank3]:     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
[rank3]:     return CheckpointFunction.apply(function, preserve, *args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 262, in forward
[rank3]:     outputs = run_function(*args)
[rank3]:               ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 803, in custom_forward
[rank3]:     return module(
[rank3]:            ^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 901, in forward
[rank3]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank3]:                                                           ^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 491, in flashattn_forward
[rank3]:     qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
[rank3]:                                                               ^^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py", line 611, in generate_qkv
[rank3]:     q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
[rank3]:                                                      ^^^^^^^^^^^^
[rank3]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 110, in unpad_input
[rank3]:     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank3]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank3]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[rank3]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
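The failing frame is flash_attn's unpad_input, which converts a padded batch into the packed "varlen" layout that the flash-attention kernels consume. The message itself warns that CUDA errors can be reported asynchronously, though, so the kernel that actually faulted may have run before this torch.nonzero call. For readers unfamiliar with the bookkeeping involved, here is a simplified pure-Python sketch of what unpad_input derives from a 0/1 padding mask: per-row lengths, flat indices of real tokens, and cumulative sequence lengths. This is an illustration and not flash_attn's code. The function name unpad_indices is hypothetical, and the sketch skips gathering the hidden states.

```python
from itertools import accumulate


def unpad_indices(attention_mask):
    """Simplified sketch of the index bookkeeping in flash_attn's unpad_input.

    attention_mask: list of rows (batch x seqlen), each entry 0 (pad) or 1 (token).
    Returns (indices, cu_seqlens, max_seqlen).
    """
    # Number of real (non-padding) tokens in each row.
    seqlens = [sum(row) for row in attention_mask]
    # Flat positions of non-zero mask entries, like torch.nonzero(mask.flatten()).
    flat = [value for row in attention_mask for value in row]
    indices = [i for i, value in enumerate(flat) if value != 0]
    # Cumulative lengths with a leading 0: the sequence boundaries in the packed layout.
    cu_seqlens = [0, *accumulate(seqlens)]
    return indices, cu_seqlens, max(seqlens)
```

For example, unpad_indices([[1, 1, 1, 0], [1, 1, 0, 0]]) returns ([0, 1, 2, 4, 5], [0, 3, 5], 3). With a 0/1 mask, len(indices) always equals cu_seqlens[-1], and the varlen kernels rely on the two agreeing.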

Steps to reproduce

Latest commit 78e12f8, installed via pip/conda, torch==2.3.1, CUDA 12.2.

Run accelerate launch -m axolotl.cli.train examples/tiny-llama/pretrain.yml (slightly modified; see the config YAML below) on an 8xH100 machine.
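As the traceback's own message suggests, setting CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so the reported stack trace points at the operation that actually failed, at the cost of speed. The shell equivalent is simply prefixing the command with CUDA_LAUNCH_BLOCKING=1. Environment variables should be inherited by the worker processes that accelerate launch spawns. Below is a minimal sketch, assuming you drive the run from Python. The helper names debug_env and launch_debug are hypothetical, and the command is the one from this report.

```python
import os
import shlex
import subprocess

# Training command from the report (hypothetical constant name).
TRAIN_CMD = (
    "accelerate", "launch", "-m", "axolotl.cli.train",
    "examples/tiny-llama/pretrain.yml",
)


def debug_env(base=None):
    """Return a copy of `base` (default: os.environ) with CUDA_LAUNCH_BLOCKING=1."""
    env = dict(os.environ if base is None else base)
    # Synchronous kernel launches: a CUDA error is raised at the call that caused it.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def launch_debug(cmd=TRAIN_CMD):
    """Print and run the training command with blocking CUDA launches."""
    print("CUDA_LAUNCH_BLOCKING=1", shlex.join(cmd))
    subprocess.run(cmd, env=debug_env(), check=True)
```

Calling launch_debug() runs the same training job. Expect it to be noticeably slower, so it is only worth using until the faulting kernel has been identified.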

Config yaml

base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 200
pretraining_dataset: 
  - path: jamesliu1/c4
# datasets:
#   - path: tatsu-lab/alpaca
#     type: alpaca

# pretraining_dataset:
#   - path: c4
#     name: en
#     type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/model-out

sequence_len: 2048
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
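For reference, here is how the batch-related values in this config work out on the 8xH100 machine. This assumes plain data parallelism across all 8 GPUs (the traceback shows DistributedDataParallel) and treats sequence_len as an upper bound on each packed row. Because max_steps is set, the Hugging Face Trainer should stop after 200 optimizer steps regardless of num_epochs.

```python
# Values copied from the config above; num_gpus reflects the 8xH100 machine.
micro_batch_size = 2
gradient_accumulation_steps = 4
num_gpus = 8
sequence_len = 2048
max_steps = 200

# Packed rows consumed per optimizer step across all data-parallel ranks.
sequences_per_step = micro_batch_size * gradient_accumulation_steps * num_gpus
# Upper bounds: assumes every packed row is filled to sequence_len.
max_tokens_per_step = sequences_per_step * sequence_len
max_tokens_total = max_tokens_per_step * max_steps

print(sequences_per_step)   # 64
print(max_tokens_per_step)  # 131072
print(max_tokens_total)     # 26214400
```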

Possible solution

Interestingly, training on an SFT dataset (e.g. the commented-out tatsu-lab/alpaca) works fine. I'm not sure what differs in the pretraining case. Any help is appreciated!

Which Operating Systems are you using?

Python Version

3.11

axolotl branch-commit

main/78e12f8

Acknowledgements

winglian commented 1 month ago

Do you get the same error when using the default/original c4 dataset?

chromecast56 commented 1 month ago

I get a different error:


[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 73, in <module>
[rank0]:     fire.Fire(do_cli)
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 39, in do_cli
[rank0]:     return do_train(parsed_cfg, parsed_cli_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/axolotl/src/axolotl/cli/train.py", line 68, in do_train
[rank0]:     return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/axolotl/src/axolotl/train.py", line 170, in train
[rank0]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/transformers/trainer.py", line 2230, in _inner_training_loop
[rank0]:     for step, inputs in enumerate(epoch_iterator):
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/accelerate/data_loader.py", line 677, in __iter__
[rank0]:     next_batch, next_batch_info = self._fetch_batches(main_iterator)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jamesliu/anaconda3/envs/axolotl/lib/python3.11/site-packages/accelerate/data_loader.py", line 635, in _fetch_batches
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.either pass `dispatch_batches=False` and have each process fetch its own batch  or pass `split_batches=True`. By doing so, the main process will fetch a full batch and slice it into `num_processes` batches for each process.
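This message names accelerate's two alternatives. The pure-Python sketch below illustrates two things. First, the idea behind split_batches=True as the error text describes it: one full batch sliced into num_processes equal shards. Second, one common way batches of different sizes arise: batching a stream whose length is not a multiple of the batch size, which leaves a short final batch. The helpers split_batch and batch_sizes are hypothetical illustrations, not accelerate's implementation, and this thread does not establish that a short final batch is what triggered the error here.

```python
def split_batch(batch, num_processes):
    """Slice one full batch into num_processes equal shards (illustration only)."""
    if len(batch) % num_processes != 0:
        raise ValueError(
            f"batch of {len(batch)} cannot be split evenly across {num_processes} processes"
        )
    shard = len(batch) // num_processes
    return [batch[i * shard:(i + 1) * shard] for i in range(num_processes)]


def batch_sizes(num_samples, batch_size):
    """Sizes of the batches produced by chunking num_samples items in order."""
    full, rest = divmod(num_samples, batch_size)
    # A non-zero remainder yields a final, smaller batch.
    return [batch_size] * full + ([rest] if rest else [])
```

For example, split_batch(list(range(8)), 4) gives four shards of two items each, while batch_sizes(10, 4) shows a stream of 10 samples ending in a batch of 2.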
chromecast56 commented 1 month ago

@winglian As a band-aid, is there a way to fully load the smaller pretraining dataset jamesliu1/c4 under datasets: instead of pretraining_dataset? (My use case is mending a base model after applying slightly lossy compression to it.) The only issue is that datasets: has no pretrain type.

thechuong98 commented 1 month ago

I'm facing the same issue!

winglian commented 1 month ago

For small pretraining-style datasets, you can use type: completion.
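Following that suggestion, you can switch a config that used pretraining_dataset over by moving each entry under datasets with type: completion. Below is a minimal sketch that operates on an already-parsed config dict, such as one returned by PyYAML's yaml.safe_load. The helper name pretraining_to_completion is hypothetical. The sketch assumes each entry is either a bare path string or a mapping such as {path: jamesliu1/c4}.

```python
import copy


def pretraining_to_completion(cfg):
    """Return a copy of `cfg` with pretraining_dataset entries moved to datasets.

    Hypothetical helper: each entry keeps its keys (path, name, ...) but its
    type is set to "completion", as suggested in this thread.
    """
    new_cfg = copy.deepcopy(cfg)
    entries = new_cfg.pop("pretraining_dataset", None) or []
    if isinstance(entries, str):
        # A bare path string rather than a list of mappings.
        entries = [{"path": entries}]
    # An empty "datasets:" key parses as None, so fall back to a new list.
    datasets = new_cfg.get("datasets") or []
    for entry in entries:
        datasets.append({**entry, "type": "completion"})
    new_cfg["datasets"] = datasets
    return new_cfg
```

Applied to the config in this report, pretraining_dataset: [{path: jamesliu1/c4}] becomes datasets: [{path: jamesliu1/c4, type: completion}]. Unlike pretraining_dataset, which streams its data, entries under datasets are presumably loaded and tokenized in full up front. That would explain why this workaround is suggested only for small datasets.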

deter3 commented 1 month ago

I had the same issue. On 06-Aug-2024 I was able to run with type: completion, but on 07-Aug-2024 I kept getting the same error: "CUDA error: an illegal memory access was encountered. CUDA kernel errors might be asynchronously reported at some other API call".

I went back to the commit from 06-Aug-2024, and it works again:

git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
git fetch origin
git checkout 203816f7b4de020c40708e4e61847b0716189380