axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0
7.58k stars, 822 forks

Gemma inference CUDA illegal memory access error #1507

Open radhacr opened 5 months ago

radhacr commented 5 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

Training with the example Gemma QLoRA config runs fine, but inference does not produce the expected response: the `generate` call fails with an error instead.

Current behaviour

Inference fails with a CUDA error:

$ accelerate launch -m axolotl.cli.inference examples/gemma/qlora.yml --lora_dir="axolotl-gemma-aplaca-qlora/"
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_processes` was set to a value of `1`
        `--num_machines` was set to a value of `1`
        `--mixed_precision` was set to a value of `'no'`
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
[2024-04-09 17:54:59,911] [INFO] [datasets.<module>:58] [PID:28431] PyTorch version 2.2.2 available.
[2024-04-09 17:55:00,910] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
                                 dP            dP   dP
                                 88            88   88
      .d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
      88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
      88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
      `88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP

[2024-04-09 17:55:02,137] [DEBUG] [axolotl.normalize_config:79] [PID:28431] [RANK:0] bf16 support detected, enabling for this configuration.
[2024-04-09 17:55:02,421] [INFO] [axolotl.normalize_config:182] [PID:28431] [RANK:0] GPU memory usage baseline: 0.000GB (+0.849GB misc)
[2024-04-09 17:55:02,423] [INFO] [axolotl.common.cli.load_model_and_tokenizer:50] [PID:28431] [RANK:0] loading tokenizer... mhenrichsen/gemma-7b
[2024-04-09 17:55:03,801] [DEBUG] [axolotl.load_tokenizer:252] [PID:28431] [RANK:0] EOS: 1 / <eos>
[2024-04-09 17:55:03,801] [DEBUG] [axolotl.load_tokenizer:253] [PID:28431] [RANK:0] BOS: 2 / <bos>
[2024-04-09 17:55:03,801] [DEBUG] [axolotl.load_tokenizer:254] [PID:28431] [RANK:0] PAD: 0 / <pad>
[2024-04-09 17:55:03,801] [DEBUG] [axolotl.load_tokenizer:255] [PID:28431] [RANK:0] UNK: 3 / <unk>
[2024-04-09 17:55:03,801] [INFO] [axolotl.load_tokenizer:266] [PID:28431] [RANK:0] No Chat template selected. Consider adding a chat template for easier inference.
[2024-04-09 17:55:03,801] [INFO] [axolotl.common.cli.load_model_and_tokenizer:52] [PID:28431] [RANK:0] loading model and (optionally) peft_config...
`low_cpu_mem_usage` was None, now set to True since model is quantized.
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:07<00:00,  1.81s/it]
[2024-04-09 17:55:12,361] [INFO] [axolotl.load_model:654] [PID:28431] [RANK:0] GPU memory usage after model load: 5.216GB (+0.028GB cache, +1.984GB misc)
[2024-04-09 17:55:12,368] [INFO] [axolotl.load_model:700] [PID:28431] [RANK:0] converting PEFT model w/ prepare_model_for_kbit_training
[2024-04-09 17:55:12,372] [INFO] [axolotl.load_model:709] [PID:28431] [RANK:0] converting modules to torch.bfloat16 for flash attention
[2024-04-09 17:55:12,375] [INFO] [axolotl.load_lora:853] [PID:28431] [RANK:0] found linear modules: ['o_proj', 'v_proj', 'down_proj', 'q_proj', 'up_proj', 'gate_proj', 'k_proj']
trainable params: 100,007,936 || all params: 8,637,688,832 || trainable%: 1.1578089688702506
[2024-04-09 17:55:13,584] [INFO] [axolotl.load_model:754] [PID:28431] [RANK:0] GPU memory usage after adapters: 5.588GB (+2.720GB cache, +1.984GB misc)
================================================================================
Give me an instruction (Ctrl + D to submit):
Give me some health tips.
========================================
<bos>Give me some health
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/radhachitta/axolotl/src/axolotl/cli/inference.py", line 36, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs) 
  File "/home/radhachitta/axolotl/src/axolotl/cli/inference.py", line 32, in do_cli
    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
  File "/home/radhachitta/axolotl/src/axolotl/cli/__init__.py", line 206, in do_inference
    generated = model.generate(
  File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1190, in generate
    outputs = self.base_model.generate(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1577, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2733, in _sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1098, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 923, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 643, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 341, in forward
    value_states = self.v_proj(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 458, in forward
    result = result.clone()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1057, in launch_command
    simple_launcher(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 673, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.10', '-m', 'axolotl.cli.inference', 'examples/gemma/qlora.yml', '--lora_dir=axolotl-gemma-aplaca-qlora/']' returned non-zero exit status 1.
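As a quick sanity check on the adapter, the `trainable params: 100,007,936` line in the log above can be reproduced from the config (`lora_r: 32`, `lora_target_linear: true`) and the seven linear modules the log lists. This minimal sketch assumes the standard Gemma-7B layer shapes (hidden size 3072, intermediate size 24576, 28 layers, 16 attention and 16 key/value heads of size 256); those dimensions come from the public model config, not from this issue.

```python
# Assumed Gemma-7B dimensions (public Hugging Face config, not stated in this issue).
HIDDEN = 3072          # hidden_size
INTERMEDIATE = 24576   # intermediate_size
ATTN_DIM = 16 * 256    # num_attention_heads * head_dim (k/v also use 16 heads)
LAYERS = 28            # num_hidden_layers
LORA_R = 32            # lora_r from the config

# (in_features, out_features) of each module under "found linear modules".
LINEAR_SHAPES = {
    "q_proj": (HIDDEN, ATTN_DIM),
    "k_proj": (HIDDEN, ATTN_DIM),
    "v_proj": (HIDDEN, ATTN_DIM),
    "o_proj": (ATTN_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(r: int = LORA_R) -> int:
    """LoRA adds A (r x in) and B (out x r) to every targeted linear layer."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in LINEAR_SHAPES.values())
    return per_layer * LAYERS


trainable = lora_param_count()
print(trainable)                                  # 100007936
print(f"{100 * trainable / 8_637_688_832:.4f}")   # 1.1578, the logged trainable%
```

The count matches the log exactly, which suggests the adapter was built with the intended shapes; on its own it does not explain the illegal memory access.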

Steps to reproduce

1. Fine-tune the Gemma model on the Alpaca dataset, with the output directory set to `axolotl-gemma-aplaca-qlora`:
`accelerate launch -m axolotl.cli.train examples/gemma/qlora.yml`
2. Run inference:
`accelerate launch -m axolotl.cli.inference examples/gemma/qlora.yml --lora_dir="axolotl-gemma-aplaca-qlora/"`
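The two steps above can also be scripted so the reproduction is repeatable. This is a minimal sketch using Python's `subprocess`; it assumes `accelerate` and axolotl are installed and that it is run from the axolotl repository root (so `examples/gemma/qlora.yml` resolves). The `run` parameter exists only so the command sequence can be checked without launching anything.

```python
import subprocess
from typing import Callable, List

CONFIG = "examples/gemma/qlora.yml"
LORA_DIR = "axolotl-gemma-aplaca-qlora/"  # output_dir from the config

# Steps 1 and 2 as argument lists, so no shell quoting is involved.
TRAIN_CMD: List[str] = ["accelerate", "launch", "-m", "axolotl.cli.train", CONFIG]
INFER_CMD: List[str] = [
    "accelerate", "launch", "-m", "axolotl.cli.inference", CONFIG,
    f"--lora_dir={LORA_DIR}",
]


def reproduce(run: Callable[..., object] = subprocess.run) -> None:
    """Run training, then inference; check=True stops at the first failing step."""
    for cmd in (TRAIN_CMD, INFER_CMD):
        run(cmd, check=True)
```

Calling `reproduce()` runs both steps. The inference step inherits stdin, so it still prompts for the instruction interactively, as in the log.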

Config yaml

```yaml
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

# huggingface repo
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
val_set_size: 0.1
output_dir: ./axolotl-gemma-aplaca-qlora/

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
```

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

main/4d6490b

Acknowledgements

winglian commented 5 months ago

Inference should really be done in a single process, using `python -m axolotl.cli.inference ...` instead of `accelerate launch`.
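To take axolotl's CLI out of the picture entirely, the same base model and adapter can be loaded in one plain Python process with `transformers` and `peft`. This is a hedged sketch, not axolotl's own inference path: it assumes a single CUDA GPU with `torch`, `transformers`, `peft` and `bitsandbytes` installed, and `run_single_process_inference` is a hypothetical name. It also uses the library's default attention implementation rather than the flash attention enabled in the config.

```python
BASE_MODEL = "mhenrichsen/gemma-7b"          # base_model in the config
ADAPTER_DIR = "axolotl-gemma-aplaca-qlora/"  # output_dir of the training run


def run_single_process_inference(prompt: str, max_new_tokens: int = 128) -> str:
    """Load the 4-bit base model plus the LoRA adapter on GPU 0 and generate once."""
    # Heavy imports are deferred so this module can be imported without them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # mirrors load_in_4bit: true
        bnb_4bit_compute_dtype=torch.bfloat16,  # the log reports bf16 support
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_config,
        device_map={"": 0},  # whole model on GPU 0, no multi-process launcher
    )
    model = PeftModel.from_pretrained(base, ADAPTER_DIR)  # attach the trained adapter
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Keep only the tokens generated after the prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

Usage: `print(run_single_process_inference("Give me some health tips."))`. If this succeeds where the CLI fails, that may point at the CLI's model-preparation steps (the log shows `prepare_model_for_kbit_training` and a bf16 conversion for flash attention running before inference) rather than at the adapter itself; treat that as a lead, not a diagnosis.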

radhacr commented 5 months ago

The same thing happens even without accelerate:

$ python  -m axolotl.cli.inference examples/gemma/qlora.yml --lora_dir="axolotl-gemma-aplaca-qlora/"
[2024-04-10 13:45:56,207] [INFO] [datasets.<module>:58] [PID:2747] PyTorch version 2.2.2 available.
[2024-04-10 13:45:58,395] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
                                 dP            dP   dP
                                 88            88   88
      .d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
      88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
      88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
      `88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP

[2024-04-10 13:46:00,709] [DEBUG] [axolotl.normalize_config:79] [PID:2747] [RANK:0] bf16 support detected, enabling for this configuration.
[2024-04-10 13:46:01,031] [INFO] [axolotl.normalize_config:182] [PID:2747] [RANK:0] GPU memory usage baseline: 0.000GB (+0.849GB misc)
[2024-04-10 13:46:01,033] [INFO] [axolotl.common.cli.load_model_and_tokenizer:50] [PID:2747] [RANK:0] loading tokenizer... mhenrichsen/gemma-7b
[2024-04-10 13:46:02,413] [DEBUG] [axolotl.load_tokenizer:252] [PID:2747] [RANK:0] EOS: 1 / <eos>
[2024-04-10 13:46:02,413] [DEBUG] [axolotl.load_tokenizer:253] [PID:2747] [RANK:0] BOS: 2 / <bos>
[2024-04-10 13:46:02,413] [DEBUG] [axolotl.load_tokenizer:254] [PID:2747] [RANK:0] PAD: 0 / <pad>
[2024-04-10 13:46:02,413] [DEBUG] [axolotl.load_tokenizer:255] [PID:2747] [RANK:0] UNK: 3 / <unk>
[2024-04-10 13:46:02,413] [INFO] [axolotl.load_tokenizer:266] [PID:2747] [RANK:0] No Chat template selected. Consider adding a chat template for easier inference.
[2024-04-10 13:46:02,413] [INFO] [axolotl.common.cli.load_model_and_tokenizer:52] [PID:2747] [RANK:0] loading model and (optionally) peft_config...
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
[2024-04-10 13:46:06,189] [INFO] [accelerate.utils.modeling.get_balanced_memory:965] [PID:2747] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [01:37<00:00, 24.41s/it]
[2024-04-10 13:47:44,972] [INFO] [axolotl.load_model:654] [PID:2747] [RANK:0] GPU memory usage after model load: 5.216GB (+0.028GB cache, +1.984GB misc)
[2024-04-10 13:47:44,982] [INFO] [axolotl.load_model:700] [PID:2747] [RANK:0] converting PEFT model w/ prepare_model_for_kbit_training
[2024-04-10 13:47:44,985] [INFO] [axolotl.load_model:709] [PID:2747] [RANK:0] converting modules to torch.bfloat16 for flash attention
[2024-04-10 13:47:44,988] [INFO] [axolotl.load_lora:853] [PID:2747] [RANK:0] found linear modules: ['gate_proj', 'q_proj', 'k_proj', 'up_proj', 'down_proj', 'v_proj', 'o_proj']
trainable params: 100,007,936 || all params: 8,637,688,832 || trainable%: 1.1578089688702506
[2024-04-10 13:47:46,191] [INFO] [axolotl.load_model:754] [PID:2747] [RANK:0] GPU memory usage after adapters: 5.588GB (+2.720GB cache, +1.984GB misc)
================================================================================
Give me an instruction (Ctrl + D to submit):
Give me some health tips.
========================================
<bos>Give me some health
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/radhachitta/axolotl/src/axolotl/cli/inference.py", line 36, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/radhachitta/axolotl/src/axolotl/cli/inference.py", line 32, in do_cli
    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
  File "/home/radhachitta/axolotl/src/axolotl/cli/__init__.py", line 206, in do_inference
    generated = model.generate(
  File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1190, in generate
    outputs = self.base_model.generate(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1577, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2733, in _sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1098, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 923, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 643, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 341, in forward
    value_states = self.v_proj(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 458, in forward
    result = result.clone()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Volko61 commented 5 months ago

I get the same issue with Qwen.

CyberNativeAI commented 3 months ago

Same issue with Codestral, exactly in the middle of an epoch:

{'loss': 0.6265, 'grad_norm': 8.632710456848145, 'learning_rate': 0.00012, 'epoch': 0.48}
 25%|███████████                                 | 6/24 [06:05<18:30, 61.71s/it]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/random.py", line 165, in fork_rng
    yield
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 271, in backward
    outputs = ctx.run_function(*detached_inputs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py", line 611, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py", line 131, in flashattn_forward
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/peft/tuners/lora/bnb.py", line 217, in forward
    result = self.base_layer(x, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 797, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 556, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 321, in forward
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/bitsandbytes/functional.py", line 2535, in double_quant
    nnz = nnz_row_ptr[-1].item()
          ^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 70, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 38, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 66, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/train.py", line 170, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3250, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 2125, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 257, in backward
    with torch.random.fork_rng(
  File "/opt/conda/lib/python3.11/contextlib.py", line 158, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/opt/conda/lib/python3.11/site-packages/torch/random.py", line 169, in fork_rng
    device_mod.set_rng_state(device_rng_state, device)
  File "/opt/conda/lib/python3.11/site-packages/torch/cuda/random.py", line 75, in set_rng_state
    _lazy_call(cb)
  File "/opt/conda/lib/python3.11/site-packages/torch/cuda/__init__.py", line 229, in _lazy_call
    callable()
  File "/opt/conda/lib/python3.11/site-packages/torch/cuda/random.py", line 73, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
winglian commented 3 months ago

If you add CUDA_LAUNCH_BLOCKING=1 to your env vars, can you report the updated stack trace?
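Because CUDA kernels launch asynchronously, the frame that raises the illegal-memory-access error is not necessarily the one that caused it (the log itself warns about this). `CUDA_LAUNCH_BLOCKING=1` makes launches synchronous, so the traceback should point at the failing kernel, at the cost of speed. It needs to be in the environment before the process initializes CUDA, so setting it for a fresh child process is the simplest route. A minimal sketch; `debug_env` and `rerun_with_blocking` are hypothetical helper names.

```python
import os
import subprocess
from typing import Callable, Dict, List, Mapping, Optional


def debug_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of the environment with synchronous CUDA kernel launches enabled."""
    env = dict(os.environ if base is None else base)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def rerun_with_blocking(cmd: List[str], run: Callable[..., object] = subprocess.run) -> object:
    # The child reads CUDA_LAUNCH_BLOCKING when it initializes CUDA, so the
    # variable is in place before any kernel is launched. check=False keeps the
    # (expected) crash from raising here; read the child's traceback instead.
    return run(cmd, env=debug_env(), check=False)
```

For example, `rerun_with_blocking(["python", "-m", "axolotl.cli.inference", "examples/gemma/qlora.yml", "--lora_dir=axolotl-gemma-aplaca-qlora/"])`. In a shell, prefixing the command with `CUDA_LAUNCH_BLOCKING=1` does the same thing.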

CyberNativeAI commented 3 months ago

Last update: I don't know what happened, but it is working now! Thanks @winglian. It might be the env var you suggested plus disabling sample_packing/pad_to_sequence_len!

@winglian thanks for your reply. With this setting it seems to crash right away. Please let me know how else I can assist.

FYI, I'm experimenting on an A100 80GB, doing 8-bit LoRA with rsLoRA on. I'm also modifying the default tokenizer to suit ChatML.

Update: it seems to run after setting:

sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false

I don't know how well it'll run, though.
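The workaround above amounts to three overrides. If configs are generated programmatically, a small helper keeps them in one place. This is a hypothetical sketch that operates on the config as a plain `dict` (for example as returned by PyYAML's `yaml.safe_load`), not an axolotl API. The commenter was unsure which change actually helped, and disabling packing typically lowers throughput because batches carry more padding.

```python
from typing import Any, Dict

# The three settings from the comment above.
PACKING_WORKAROUND: Dict[str, Any] = {
    "sample_packing": False,
    "eval_sample_packing": False,
    "pad_to_sequence_len": False,
}


def apply_packing_workaround(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of cfg with sample packing and padding-to-length disabled."""
    patched = dict(cfg)
    patched.update(PACKING_WORKAROUND)
    return patched


# Relevant keys from the Gemma QLoRA config earlier in this issue.
original = {"sequence_len": 4096, "sample_packing": True,
            "eval_sample_packing": False, "pad_to_sequence_len": True}
patched = apply_packing_workaround(original)
print(patched["sample_packing"], patched["pad_to_sequence_len"])  # False False
print(original["sample_packing"])                                 # True (input unchanged)
```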

Also, I have some extra data I can't use because of an error too. It is ChatML-formatted, so I don't know what is wrong:

[2024-06-07 02:56:22,334] [INFO] [axolotl.callbacks.on_train_begin:785] [PID:2737] [RANK:0] The Axolotl config has been saved to the WandB run under files.
  0%|                                                    | 0/50 [00:00<?, ?it/s]
[2024-06-07 02:56:22,336] [INFO] [axolotl.utils.samplers.multipack._len_est:185] [PID:2737] [RANK:0] packing_efficiency_estimate: 1.0 total_num_tokens per device: 641630
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 70, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 38, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/cli/train.py", line 66, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/work/axolotl/src/axolotl/train.py", line 170, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3250, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 2125, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 321, in backward
    _flash_attn_varlen_backward(
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 181, in _flash_attn_varlen_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

CyberNativeAI commented 3 months ago

I was able to train the model with the environment variable @winglian suggested:

CUDA_LAUNCH_BLOCKING=1
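For reference, `CUDA_LAUNCH_BLOCKING=1` makes CUDA kernel launches synchronous. Because CUDA normally runs asynchronously, an error like the illegal memory access above is otherwise reported at a later sync point, so the stack trace may not show the call that caused it. Synchronous launches also slow training down. A minimal sketch of launching the same training entry point with the variable set, assuming `accelerate` is on `PATH`; `build_launch` and `config.yml` are hypothetical names used only for illustration:

```python
import os
import shlex


def build_launch(config_path: str) -> tuple[list[str], dict[str, str]]:
    """Return the command and environment for a training run with
    synchronous CUDA kernel launches."""
    cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path]
    # Copy the current environment instead of mutating os.environ; the child
    # process sees the variable before it creates its CUDA context.
    env = {**os.environ, "CUDA_LAUNCH_BLOCKING": "1"}
    return cmd, env


cmd, env = build_launch("config.yml")  # hypothetical config path
print("CUDA_LAUNCH_BLOCKING=" + env["CUDA_LAUNCH_BLOCKING"], shlex.join(cmd))
```

To start the run, pass both to `subprocess.run(cmd, env=env, check=True)`. In a shell, prefixing the usual `accelerate launch` command with `CUDA_LAUNCH_BLOCKING=1` has the same effect.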

And the following config:

base_model: mistralai/Codestral-22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: long_sys_msg_all_data_v000.jsonl
    conversation: chatml
    type: sharegpt

test_datasets:
  - path: 000refined_neo_dataset_v2eval.jsonl
    split: train
    conversation: chatml
    type: sharegpt

chat_template: chatml

adapter: lora
peft_use_rslora: true

lora_r: 64
lora_alpha: 32
lora_dropout: 0
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

dataset_prepared_path:
val_set_size: 0
output_dir: Colibri22bOut

sequence_len: 4096
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false

save_safetensors: true

wandb_project: Colibri22b
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

save_total_limit: 1
gradient_accumulation_steps: 6
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: True
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
save_strategy: "no"

warmup_steps: 10
evals_per_epoch: 2
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
  bos_token: "<s>"
  eos_token: "<|im_end|>"
  unk_token: "<unk>"
lora_modules_to_save:
  - embed_tokens
  - lm_head
tokens:
  - "<|im_start|>"
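Because this config sets `peft_use_rslora: true`, the LoRA update should be scaled by `lora_alpha / sqrt(lora_r)` rather than the usual `lora_alpha / lora_r`. This is how PEFT's rank-stabilized LoRA option (`use_rslora`) computes the scaling, and the sketch assumes axolotl passes `peft_use_rslora` through to it. A quick check of what that means for `lora_r: 64` and `lora_alpha: 32` (the helper name is hypothetical):

```python
import math


def lora_scaling(lora_alpha: float, lora_r: int, use_rslora: bool) -> float:
    """Scale factor applied to the low-rank update B @ A."""
    if use_rslora:
        # Rank-stabilized LoRA: divide by sqrt(r) instead of r.
        return lora_alpha / math.sqrt(lora_r)
    return lora_alpha / lora_r


# Values from the config above.
standard = lora_scaling(32, 64, use_rslora=False)  # 32 / 64 = 0.5
rslora = lora_scaling(32, 64, use_rslora=True)     # 32 / 8  = 4.0
print(f"standard: {standard}, rslora: {rslora}, ratio: {rslora / standard}")
```

With these settings, the adapter's update is scaled 8x more strongly than plain LoRA would scale it at the same `lora_alpha`.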

Eval loss looks delicious! This model is hella smart!! [image: eval loss plot]

And train loss: [image: train loss plot]

The training dataset is only 352 refined examples.
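For scale: with `micro_batch_size: 1`, `gradient_accumulation_steps: 6` and `sample_packing: false`, each optimizer step sees 6 examples per GPU. The sketch below gives a rough step count for 352 examples over `num_epochs: 2`. It assumes a single GPU, since the GPU count isn't stated, and rounds a trailing partial accumulation up to one step. The trainer's exact rounding may differ, and the helper name is hypothetical:

```python
import math


def approx_optimizer_steps(num_examples: int, micro_batch_size: int,
                           grad_accum: int, num_gpus: int,
                           num_epochs: int) -> tuple[int, int]:
    """Return (effective batch size, approximate total optimizer steps)."""
    effective_batch = micro_batch_size * grad_accum * num_gpus
    # Round up so a trailing partial batch counts as one step; treat the
    # result as an estimate, not the trainer's exact count.
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * num_epochs


batch, steps = approx_optimizer_steps(352, 1, 6, num_gpus=1, num_epochs=2)
print(batch, steps)
```

That is roughly 118 optimizer steps in total, so `warmup_steps: 10` covers a bit under a tenth of the run.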

Hmm... when I use the model for inference, though, the results are not good. I'll try to figure out why next time.
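The model was trained with `chat_template: chatml` and `<|im_end|>` as the EOS token. When outputs look poor, one thing worth ruling out is an inference prompt that doesn't follow the same layout. Below is a minimal, library-free sketch of the standard ChatML format. The helper name and example messages are hypothetical:

```python
def to_chatml(messages: list[dict[str, str]],
              add_generation_prompt: bool = True) -> str:
    """Render role/content messages in ChatML layout."""
    parts = []
    for msg in messages:
        # Each turn: <|im_start|>role\ncontent<|im_end|>\n
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n")
    if add_generation_prompt:
        # Open an assistant turn so the model writes the reply and ends it
        # with <|im_end|>, the eos_token set in the config.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = to_chatml([
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write hello world in Python."},
])
print(prompt)
```

If the saved tokenizer carries a ChatML template, the string from transformers' `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` can be compared against this one. Generation should also stop on `<|im_end|>`.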