huggingface / alignment-handbook

Robust recipes to align language models with human and AI preferences
https://huggingface.co/HuggingFaceH4
Apache License 2.0

DPO fine-tuning errors out on Yi 34B (Assertion `srcIndex < srcSelectDimSize` failed) #77

Open cvetanovskaa opened 7 months ago

cvetanovskaa commented 7 months ago

The DPO script errors out only with Yi 34B Chat. Llama 2 7B and 13B and SUSTech/SUS-Chat-34B all train without problems, but Yi 34B Chat consistently runs into the following error:

Traceback:

Assertion `srcIndex < srcSelectDimSize` failed.
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [54,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [55,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [56,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [58,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [59,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [60,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [234,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292

----

  File "/home/ec2-user/SageMaker/alignment-handbook/scripts/dpo/run_dpo.py", line 238, in <module>
    main()
  File "/home/ec2-user/SageMaker/alignment-handbook/scripts/dpo/run_dpo.py", line 186, in main
    train_result = dpo_trainer.train()
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 594, in compute_loss
    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 545, in get_batch_metrics
    ) = self.concatenated_forward(model, batch)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 511, in concatenated_forward
    all_logits = model(
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1814, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/peft/peft_model.py", line 1003, in forward
    return self.base_model(
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 106, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 879, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: CUDA error: device-side assert triggered
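The traceback ends in `torch.embedding`, called from `self.embed_tokens(input_ids)`, and `srcIndex < srcSelectDimSize` is the bounds check of the CUDA index-select kernel behind that lookup. That suggests at least one token id in the batch falls outside `[0, num_embeddings)` for Yi's embedding table. The sketch below is a plain-Python check for that condition, assuming a batch of input ids can be pulled out of the dataloader as lists of ints; the helper name `find_out_of_range_ids` is hypothetical and not part of TRL or the handbook.

```python
from typing import Iterable, List, Sequence, Tuple


def find_out_of_range_ids(
    batch_input_ids: Iterable[Sequence[int]], num_embeddings: int
) -> List[Tuple[int, int, int]]:
    """Return (row, position, token_id) for every id outside [0, num_embeddings).

    Hypothetical diagnostic helper: an embedding table with `num_embeddings`
    rows can only look up ids in that half-open range, so any id reported
    here would make the embedding lookup fail.
    """
    bad = []
    for row, ids in enumerate(batch_input_ids):
        for pos, tok in enumerate(ids):
            if tok < 0 or tok >= num_embeddings:
                bad.append((row, pos, int(tok)))
    return bad


if __name__ == "__main__":
    # Toy batch: the table has 8 rows, so the ids 8 and -100 are both invalid.
    toy_batch = [[1, 2, 3, 7], [4, 8, 0, -100]]
    print(find_out_of_range_ids(toy_batch, num_embeddings=8))
```

With the real model, `num_embeddings` could be read from `model.get_input_embeddings().num_embeddings`. If `len(tokenizer)` or `tokenizer.pad_token_id` is not smaller than that value, the tokenizer or the padding step can emit ids the model cannot embed. This is a diagnostic for the symptom, not a confirmed root cause for this issue.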

---

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f675c18e617 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f675c14998d in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f675c24a128 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x80 (0x7f66dade5240 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7f66dade9068 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::workCleanupLoop() + 0x250 (0x7f66dadff900 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x78 (0x7f66dadffc08 in /home/ec2-user/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xbaacf (0x7f6767877acf in /lib64/libstdc++.so.6)
frame #8: <unknown function> + 0x744b (0x7f677763a44b in /lib64/libpthread.so.0)
frame #9: clone + 0x3f (0x7f6776a2652f in /lib64/libc.so.6)

terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: CUDA error: device-side assert triggered
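CUDA kernel launches are asynchronous, so a device-side assert can be reported from a later call than the one that launched the failing kernel; here it also surfaces in the NCCL watchdog thread on rank 0. A common debugging aid, assumed here rather than taken from the original report, is to force synchronous launches with the `CUDA_LAUNCH_BLOCKING` environment variable so the Python traceback points at the kernel that actually failed. The function name below is hypothetical.

```python
import os


def enable_synchronous_cuda_launches() -> None:
    """Hypothetical helper: make CUDA kernel launches synchronous for debugging.

    CUDA_LAUNCH_BLOCKING is read when the CUDA context is created, so call this
    at the very top of the training script, before any tensor is moved to a GPU.
    Exporting the variable in the launching shell has the same effect.
    """
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


if __name__ == "__main__":
    enable_synchronous_cuda_launches()
```

Synchronous launches slow training noticeably, so this is only meant for reproducing the failure, not for normal runs.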

Environment: