huggingface / transformers

πŸ€— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [267,0,0], thread: [25,0,0] Assertion `srcIndex < srcSelectDimSize` failed. #33985

Open JHW5981 opened 4 days ago

JHW5981 commented 4 days ago

System Info

Who can help?

@muellerzr @SunMarc

Information

Tasks

Reproduction

Multi-card training works fine on 20-series and 30-series GPUs. However, when training with multiple 4090 cards, the following error occurs: ../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [267,0,0], thread: [20,0,0] Assertion `srcIndex < srcSelectDimSize` failed. There is no issue when using a single 4090 card.

1. I used ChatGPT to generate a test training script for my environment; the code is shown below:

from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import Dataset

model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.resize_token_embeddings(len(tokenizer))

# 100 toy examples (50 positive / 50 negative)
data = {
    "text": ["This is a positive example", "This is a negative example"] * 50,
    "label": [1, 0] * 50,
}

dataset = Dataset.from_dict(data)

def preprocess_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)

encoded_dataset = dataset.map(preprocess_function, batched=True)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="no",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset,
)

trainer.train()

print("Training completed.")

2. When I use only one GPU:

CUDA_VISIBLE_DEVICES=5 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python test.py

Everything is ok.

3. When I use multiple GPUs:

CUDA_VISIBLE_DEVICES=5,6 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python test.py

The following error appears:
/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.
A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.
A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.
...
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:00<00:00, 5947.60 examples/s]
/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of πŸ€— Transformers. Use `eval_strategy` instead
  warnings.warn(
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                                                                                                                                                            | 0/7 [00:00<?, ?it/s]/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
 14%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹                                                                                                                                                          | 1/7 [00:03<00:22,  3.74s/it]../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [126,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [126,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
......
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [159,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "/home/jihuawei2/projects/AceRead/model/test.py", line 47, in <module>
    trainer.train()
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 1948, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 3328, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 3373, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 1695, in forward
    outputs = self.bert(
              ^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 1107, in forward
    extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py", line 449, in _prepare_4d_attention_mask_for_sdpa
    if not is_tracing and torch.all(mask == 1):
                          ^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

 14%|β–ˆβ–        | 1/7 [00:04<00:26,  4.38s/it]    
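As the error message itself suggests, re-running with CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so the reported stack trace points at the operation that actually failed rather than a later API call. A minimal sketch of one way to set it (it just has to be in place before CUDA is initialised):

# Put this at the very top of test.py, before any CUDA work happens, so the
# device-side assert is reported at the call that triggered it.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported after setting the env var on purpose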

Expected behavior

I believe this error is related to how the Hugging Face Trainer handles parallel computation across multiple GPUs. Ultimately, I hope to get multi-card training working on the 4090s.
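For context, my understanding is that when the script is started with plain python and several GPUs are visible, the Trainer wraps the model in torch.nn.DataParallel rather than using DDP. A minimal sketch (assuming the same ./results output dir as above) to check which mode is picked up:

import torch
from transformers import TrainingArguments

args = TrainingArguments(output_dir="./results")
print("visible GPUs:", torch.cuda.device_count())
print("n_gpu seen by the Trainer:", args.n_gpu)    # > 1 means the model will be wrapped in nn.DataParallel
print("parallel_mode:", args.parallel_mode)        # NOT_DISTRIBUTED unless launched via torchrun / accelerate launch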

ArthurZucker commented 4 days ago

Hey! I would recommend running on CPU first; that will tell you where this issue comes from! πŸ€— Leaving this for the community to help with as well
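This assertion usually indicates an index that is out of range for an embedding lookup; on CPU the same lookup fails with a readable IndexError instead of a device-side assert. A minimal sketch (reusing the model and tokenizer from the repro above) that runs one batch on CPU and compares the largest token id with the size of the input embedding table:

import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

batch = tokenizer(
    ["This is a positive example", "This is a negative example"],
    padding="max_length", truncation=True, max_length=64, return_tensors="pt",
)
print("max input id:", batch["input_ids"].max().item())
print("embedding rows:", model.get_input_embeddings().num_embeddings)

with torch.no_grad():
    model(**batch)  # on CPU an out-of-range id raises a plain IndexError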

JHW5981 commented 3 days ago

@ArthurZucker Thank you for your reply. The code runs without any issues on the CPU, and there are no problems with parallel training on four 3090s or four 2080s, nor with training on a single 4090. However, as soon as I use the trainer and set CUDA_VISIBLE_DEVICES to multiple GPUs, the issue arises.

ArthurZucker commented 3 days ago

Okay, in that case let me ping @muellerzr and @SunMarc for this! πŸ€—

SunMarc commented 2 days ago

Hi @JHW5981, thanks for the report! Could you check whether running NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python test.py works? It should use all the available GPUs. If the issue only exists with the 4090s, I suggest updating your driver!
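If it still fails, the driver / CUDA / NCCL versions would also help; here is a small sketch to collect what PyTorch itself sees (the driver version is easiest to read from nvidia-smi):

import torch

print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("NCCL:", torch.cuda.nccl.version())
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i), torch.cuda.get_device_capability(i))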

JHW5981 commented 2 days ago

@SunMarc Thanks for your reply. I tried NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python test.py and found that it does not work either. Below is the error output:

Traceback (most recent call last):
  File "/home/jihuawei2/projects/AceRead/dustbin/test.py", line 42, in <module>
    main()
  File "/home/jihuawei2/projects/AceRead/dustbin/test.py", line 37, in main
    trainer.train()
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 1948, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 3328, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/trainer.py", line 3373, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 1695, in forward
    outputs = self.bert(
              ^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 1141, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 694, in forward
    layer_outputs = layer_module(
                    ^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 584, in forward
    self_attention_outputs = self.attention(
                             ^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 514, in forward
    self_outputs = self.self(
                   ^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 394, in forward
    query_layer = self.transpose_for_scores(self.query(hidden_states))
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jihuawei2/miniconda3/envs/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

  0%|          | 0/2 [00:16<?, ?it/s]  

I don't think the problem is related to the 4090 device itself, because when I use πŸ€— accelerate to launch on multiple GPUs it works. So I suspect the problem may be caused by some incompatible behavior of accelerate inside the Trainer? The example code is shown below:

from accelerate import Accelerator
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import torch

class SimpleDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

texts = ["Hello, world!", "Hugging Face is great!"]
labels = [0, 1]
dataset = SimpleDataset(texts, labels)
dataloader = DataLoader(dataset, batch_size=2)

accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model, dataloader = accelerator.prepare(model, dataloader)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(3):
    for batch in dataloader:
        optimizer.zero_grad()
        inputs = tokenizer(batch[0], padding=True, truncation=True, return_tensors="pt")
        outputs = model(**inputs, labels=batch[1])
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
    print(f"Epoch {epoch + 1} completed.")

I use accelerate launch test.py, and it works successfully (I only use two gpu_ids here, because the other cards are being used by other people):

Epoch 1 completed.
Epoch 1 completed.
Epoch 2 completed.
Epoch 2 completed.
Epoch 3 completed.
Epoch 3 completed.
SunMarc commented 2 days ago

Maybe the issue is with DataParallel? Try to use DDP with accelerate instead by calling accelerate launch --multi_gpu myscript.py. More docs here: https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#:~:text=You%20can%20also,mixed%20precision%20disabled%3A
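For example, something like accelerate launch --multi_gpu --num_processes 2 test.py should start one process per GPU. As a sanity check, here is a small sketch you could drop into the script right after the Trainer is constructed, to confirm DDP (and not DataParallel) is active:

import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
    print(f"DDP: rank {dist.get_rank()} of {dist.get_world_size()}")
else:
    # with multiple visible GPUs and no process group, the Trainer falls back to nn.DataParallel
    print("torch.distributed is not initialised")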

JHW5981 commented 2 days ago

Okay, thank you for your patient explanation. I’ll go take a look.