lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0

srcIndex < srcSelectDimSize error when tuning with Llama 2 #2038

Closed: alwayshalffull closed this issue 1 year ago

alwayshalffull commented 1 year ago

I am getting the following error when trying to fine-tune the 7B and 13B models from a Llama 2 base:

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

I'm using the same dataset and environment configuration as when I successfully fine-tuned from a Llama 1 base.

A few things that I've tried:

Has anyone else run into this, and if so, what were your steps to mitigate it? I suspect it might be related to the tokenizer. Thanks!
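
For reference, this assertion usually comes from the embedding index_select when some input token ID is greater than or equal to the number of embedding rows, so a quick sanity check (my own sketch, not FastChat code) is to compare the tokenizer against the checkpoint's config:

# Sanity-check sketch (not FastChat code): the indexSelectLargeIndex assert usually
# means some input ID >= the embedding table size, so compare the tokenizer's
# vocabulary (including any added special tokens) with the checkpoint's vocab_size.
from transformers import AutoConfig, AutoTokenizer

MODEL = "meta-llama/Llama-2-7b-hf"

config = AutoConfig.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=False)

print("config.vocab_size:", config.vocab_size)  # embedding rows in the checkpoint
print("len(tokenizer)   :", len(tokenizer))     # base vocab plus added special tokens
assert len(tokenizer) <= config.vocab_size, "added tokens exceed the embedding table"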

Environment

Full Traceback:

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [187,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "/home/ubuntu/sky_workdir/fastchat/train/train_mem.py", line 13, in <module>
    train()
  File "/home/ubuntu/sky_workdir/fastchat/train/train.py", line 277, in train
    trainer.train()
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/sky_workdir/fastchat/train/llama_flash_attn_monkey_patch.py", line 85, in forward
    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 108, in unpad_input
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232190 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232191 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232192 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232193 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232194 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232196 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2232197 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 5 (pid: 2232195) of binary: /home/ubuntu/miniconda3/envs/fastchat/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/fastchat/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/miniconda3/envs/fastchat/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
fastchat/train/train_mem.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-07-20_22:16:53
  host      : 140-238-230-107
  rank      : 5 (local_rank: 5)
  exitcode  : -6 (pid: 2232195)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2232195
========================================================
Building synchronization state...
Starting synchronization...

Here's the command I'm using (based on the example script train_vicuna_7b.sh):

torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf  \
    --data_path ~/data.json \
    --bf16 True \
    --output_dir ~/.checkpoints \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "steps" \
    --eval_steps 1500 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 8 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --lazy_preprocess True
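
Since the dataset is the one thing that carried over unchanged from my Llama 1 runs, a rough scan along these lines (a sketch, assuming the usual FastChat conversation JSON format; the path is reused from --data_path above) should show whether any sample tokenizes to an out-of-range ID:

# Rough sketch: scan the training JSON for samples whose tokenization produces an
# ID outside the embedding table. Assumes the FastChat conversation format
# ([{"conversations": [{"from": ..., "value": ...}, ...]}, ...]).
import json
import os
from transformers import AutoConfig, AutoTokenizer

MODEL = "meta-llama/Llama-2-7b-hf"
DATA_PATH = os.path.expanduser("~/data.json")  # same file as --data_path above

config = AutoConfig.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=False)

bad_samples = []
with open(DATA_PATH) as f:
    data = json.load(f)
for i, sample in enumerate(data):
    for turn in sample.get("conversations", []):
        ids = tokenizer(turn["value"], add_special_tokens=False)["input_ids"]
        if ids and max(ids) >= config.vocab_size:
            bad_samples.append(i)
            break
print(f"{len(bad_samples)} / {len(data)} samples contain out-of-range token IDs")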
merrymercy commented 1 year ago

Try https://github.com/lm-sys/FastChat/blob/main/fastchat/data/optional_replace.py to clean your data.
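
For context, an illustrative sketch of this kind of cleanup (not the actual logic of optional_replace.py): replace unusual control/format characters in the conversation text before training.

# Illustrative sketch only; see fastchat/data/optional_replace.py for the real script.
# Idea: rewrite characters that tend to interact badly with the tokenizer.
import json
import unicodedata

IN_PATH = "data.json"               # hypothetical input/output paths
OUT_PATH = "data_cleaned.json"

def clean_text(text, replacement=" "):
    out = []
    for ch in text:
        if ch in ("\n", "\t"):
            out.append(ch)          # keep ordinary whitespace
        elif unicodedata.category(ch) in {"Cc", "Cf", "Zl", "Zp", "Co", "Cs"}:
            out.append(replacement)  # control, format, separators, private-use, surrogates
        else:
            out.append(ch)
    return "".join(out)

with open(IN_PATH) as f:
    data = json.load(f)
for sample in data:                  # assumes the FastChat conversation JSON format
    for turn in sample.get("conversations", []):
        turn["value"] = clean_text(turn["value"])
with open(OUT_PATH, "w") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)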

alwayshalffull commented 1 year ago

That did the trick, thanks for the suggestion!