xspirus opened this issue
cc @muellerzr @SunMarc
I think changing https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/trainer.py#L5060-L5061 to
if self.args.average_tokens_across_devices:
    num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
if num_items_in_batch is not None:
    num_items_in_batch = num_items_in_batch.item()
may be an easier way to fix this bug: the final .item() call turns num_items_in_batch into a plain Python number, so dividing the loss by it no longer depends on which device the loss lives on.
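For intuition, here is a tiny standalone snippet (assuming at least two visible CUDA devices; the variable names only mirror the situation in the Trainer, they are not its code) showing why dividing by a count tensor on another device fails while dividing by the Python number from .item() works:

```python
import torch

loss = torch.tensor(3.0, device="cuda:1")              # like the loss on the last model shard
num_items_in_batch = torch.tensor(7, device="cuda:0")  # like the count built from the inputs

try:
    loss / num_items_in_batch        # tensors on different devices -> RuntimeError
except RuntimeError as err:
    print(err)                       # device-mismatch error

print(loss / num_items_in_batch.item())  # a plain Python int divides fine on any device
```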
System Info
transformers version: 4.46.3
The model is loaded with device_map="auto".

Who can help?
@ArthurZucker @muellerzr
While trying to train Qwen/Qwen2.5-7B-Instruct or meta-llama/Llama-3.1-8B-Instruct using the SFTTrainer of the trl library, on a machine with 4 L4 GPUs, an error is raised during the forward pass, at the point where the loss is about to be calculated, because two tensors end up on different devices.

This stems from the changes in version 4.46.0, where https://github.com/huggingface/transformers/pull/34191 was introduced. What I suspect is going on is that num_items_in_batch is calculated from the batch sampler of the inputs (which are probably placed on cuda:0), while the loss is calculated from the outputs of the model (which end up on the last GPU, cuda:3), thus creating the error.

I am not sure if the fix is as simple as the following piece of code:
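Sketched here as a standalone helper just to illustrate the idea (the helper name is mine, not an existing function; the real change would have to live wherever transformers divides the loss by num_items_in_batch):

```python
import torch

def scale_loss_by_token_count(loss: torch.Tensor, num_items_in_batch) -> torch.Tensor:
    # Hypothetical helper, not existing transformers code: divide the summed loss
    # by the number of items in the batch, moving the count to the loss' device
    # first so a model sharded with device_map="auto" does not mix cuda:0/cuda:3.
    if num_items_in_batch is None:
        return loss
    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = num_items_in_batch.to(loss.device)
    return loss / num_items_in_batch
```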
where the fix lies in the num_items_in_batch.to(loss.device) call.

I am creating this issue so that you can solve it more confidently, since you are more familiar with this part of the codebase.
NOTE: the error does not occur on transformers v4.45.2.

Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
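Roughly, the training script looks like the following (the dataset, sequence length, and training arguments here are placeholders rather than my exact script and requirements):

```python
# Causal LM sharded over the 4 GPUs with device_map="auto", trained with trl's SFTTrainer.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",   # same behavior with meta-llama/Llama-3.1-8B-Instruct
    torch_dtype=torch.bfloat16,
    device_map="auto",            # shards the model across cuda:0 ... cuda:3
)

train_dataset = load_dataset("stanfordnlp/imdb", split="train")  # placeholder dataset with a "text" column

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="sft-out", per_device_train_batch_size=1, max_seq_length=1024),
    train_dataset=train_dataset,
)
trainer.train()  # on transformers v4.46.x this fails when the loss is computed
```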
Expected behavior
Training should occur normally, as it does in v4.45.2.