huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

CausalLM loss function throws runtime error in multi-gpu setup #35086

Open xspirus opened 13 hours ago

xspirus commented 13 hours ago

System Info

Who can help?

@ArthurZucker @muellerzr

While training Qwen/Qwen2.5-7B-Instruct or meta-llama/Llama-3.1-8B-Instruct with the SFTTrainer from the trl library on a machine with 4 L4 GPUs, the following error occurs during the forward pass, at the point where the loss is calculated:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/xspirus/sample-project/sample.py", line 179, in <module>
    sft()
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/src/llms/sft/__main__.py", line 168, in sft
    trainer.train()
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3633, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1183, in forward
    loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 46, in ForCausalLMLoss
    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 28, in fixed_cross_entropy
    loss = loss / num_items_in_batch
           ~~~~~^~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!

This stems from the changes in version 4.46.0, where https://github.com/huggingface/transformers/pull/34191 was introduced. I suspect that num_items_in_batch is calculated based on the batch sampler of the inputs (and the inputs are probably placed on cuda:0), while the loss is calculated based on the outputs of the model (which are placed on the last GPU, cuda:3). Dividing the loss by num_items_in_batch then mixes tensors from two devices, which raises the error.
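One way to check this suspicion is to print where each operand lives just before the division. The helper below is a minimal sketch; the name describe_devices is hypothetical and is not part of transformers. It runs on CPU, so it only shows the reporting logic, not the multi-GPU failure itself.

```python
import torch


def describe_devices(**named_values):
    """Map each name to its tensor device, or to its Python type if it is not a tensor."""
    report = {}
    for name, value in named_values.items():
        if torch.is_tensor(value):
            report[name] = str(value.device)
        else:
            report[name] = type(value).__name__
    return report


# CPU stand-ins for the two operands of `loss / num_items_in_batch`.
loss = torch.tensor(12.0)
num_items_in_batch = torch.tensor(4)
print(describe_devices(loss=loss, num_items_in_batch=num_items_in_batch))
```

In the failing setup, a report showing cuda:3 for the loss and cuda:0 for num_items_in_batch would match the error message above.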

I am not sure if the fix is as simple as the following piece of code:

def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch.to(loss.device)
    return loss

The only change is the num_items_in_batch.to(loss.device) call in the division.
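Since the signature annotates num_items_in_batch as an int, calling .to() on it would fail if a plain Python number were ever passed. A slightly more defensive variant, offered only as a sketch, moves the count only when it is a tensor. The function name fixed_cross_entropy_sketch is hypothetical, and whether both input types actually occur in transformers is an assumption.

```python
import torch
import torch.nn as nn


def fixed_cross_entropy_sketch(source, target, num_items_in_batch=None, ignore_index=-100):
    # Sum over tokens when a global count is given, otherwise let PyTorch average.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # Assumption: the count may arrive as a tensor or as a plain number.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss


# CPU-only usage: four of the five labels are not ignored.
logits = torch.randn(5, 7)
labels = torch.tensor([1, 2, -100, 3, 0])
print(fixed_cross_entropy_sketch(logits, labels, torch.tensor(4)))
```

With a count equal to the number of non-ignored labels, the result should match the default mean reduction, which is a convenient sanity check.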

I am opening this issue so that you can decide on the right fix, since you are more familiar with this part of the codebase.

NOTE: the error does not occur on transformers v4.45.2.

Reproduction

import tempfile
from pathlib import Path
from typing import cast

import click
import torch
from datasets import load_dataset
from peft.mapping import get_peft_model
from peft.tuners.lora import LoraConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizerFast
from transformers.training_args import IntervalStrategy
from trl import SFTConfig
from trl import SFTTrainer

@click.command(name="sft", help="Train LLM with SFT technique.")
@click.option("--model", "model_name", type=str, required=True, help="Base model name.")
@click.option(
    "--output",
    type=click.Path(exists=False, dir_okay=True, writable=True, resolve_path=True, path_type=Path),
    required=True,
    help="Output directory.",
)
@click.option(
    "--use-lora",
    is_flag=True,
    help="Whether to use LoRA or not.",
)
def sft(model_name: str, output: Path, use_lora: bool = False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError("Tokenizer must be fast tokenizer.")
    if not tokenizer.pad_token:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "right"

    data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    data = data.shuffle()

    model = cast(
        PreTrainedModel,
        AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            use_cache=False,
        ),
    )

    if use_lora:
        lora_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=32,
            lora_alpha=64,
            lora_dropout=0.05,
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(
            output_dir=tmp_dir,
            bf16=model.dtype == torch.bfloat16,
            fp16=model.dtype == torch.float16,
            learning_rate=2e-4,
            neftune_noise_alpha=5,
            num_train_epochs=3,
            packing=False,
            per_device_eval_batch_size=4,
            per_device_train_batch_size=4,
            save_strategy=IntervalStrategy.NO,
            warmup_ratio=0.03,
            weight_decay=1e-3,
        )

        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,
            args=training_args,
            train_dataset=data,
        )
        trainer.train()

    tokenizer.padding_side = original_padding_side
    tokenizer.save_pretrained(str(output))
    model.save_pretrained(str(output))

if __name__ == "__main__":
    sft()

The pinned requirements are:

accelerate==1.1.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.9
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
bitsandbytes==0.44.1
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
colorama==0.4.6 ; platform_system == 'Windows'
datasets==3.1.0
dill==0.3.8
distro==1.9.0
einops==0.8.0
filelock==3.16.1
flash-attn==2.7.0.post2
frozenlist==1.5.0
fsspec==2024.9.0
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
h11==0.14.0
httpcore==1.0.7
httpx==0.28.0
huggingface-hub==0.26.3
idna==3.10
jinja2==3.1.4
jiter==0.8.0
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.3.9
langchain-core==0.3.21
langchain-openai==0.2.11
langchain-text-splitters==0.3.2
langsmith==0.1.147
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
openai==1.56.2
orjson==3.10.12 ; platform_python_implementation != 'PyPy'
packaging==24.2
pandas==2.2.3
peft==0.13.2
propcache==0.2.1
psutil==6.1.0
pyarrow==18.1.0
pydantic==2.10.3
pydantic-core==2.27.1
pygments==2.18.0
python-dateutil==2.9.0.post0
pytz==2024.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.9.4
ruff==0.8.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
setuptools==75.6.0
six==1.16.0
sniffio==1.3.1
sqlalchemy==2.0.36
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.3
torch==2.5.1
tqdm==4.67.1
transformers==4.46.3
triton==3.1.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
trl==0.12.1
typing-extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.18.3

Expected behavior

Training should proceed normally, as it does in v4.45.2.

Rocketknight1 commented 13 hours ago

cc @muellerzr @SunMarc

techkang commented 1 hour ago

I think changing https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/trainer.py#L5060-L5061 to

        if self.args.average_tokens_across_devices:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
        if num_items_in_batch is not None:
            num_items_in_batch = num_items_in_batch.item()

may be an easier way to fix this bug.
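The idea behind this suggestion is that a plain Python number has no device, so dividing a loss tensor by it cannot trigger a device mismatch. The snippet below is a minimal CPU-only sketch of that idea; as_python_count is a hypothetical helper and not part of the Trainer.

```python
import torch


def as_python_count(num_items_in_batch):
    """Return the token count as a plain Python number, or None if it is absent."""
    if torch.is_tensor(num_items_in_batch):
        # .item() copies the single value to the host as a Python scalar.
        return num_items_in_batch.item()
    return num_items_in_batch


summed_loss = torch.tensor(12.0)          # stands in for a summed loss on the last GPU
count = as_python_count(torch.tensor(4))  # stands in for a count held on the first GPU
mean_loss = summed_loss / count           # tensor divided by a Python int
print(mean_loss)
```

One trade-off to keep in mind is that .item() on a CUDA tensor waits for the device to finish pending work, so it adds a host synchronization on each call.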