huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer does not call torch.compile when torch_compile=True in TrainingArguments #34656

Open · singularity-s0 opened this issue 2 weeks ago

singularity-s0 commented 2 weeks ago

System Info

Who can help?

@muellerzr @SunMa

Information

Tasks

Reproduction

Using the following test script:

from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments
import torch
import tempfile
import logging

device = "cuda" if torch.cuda.is_available() else "cpu"

# Minimal map-style dataset that returns the same sample for every index.
class RepeatDataset:
    def __init__(self, x, length=64):
        self.x = x
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}

config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

def test_torch_compile_hf_trainer():
    tiny_llama = LlamaForCausalLM(config).to(device)
    with tempfile.TemporaryDirectory() as tmp_dir:
        args = TrainingArguments(
            tmp_dir,
            per_device_train_batch_size=2,
            torch_compile=True,
            max_steps=1,  # compile happens on the first step
        )
        trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset)  # noqa
        trainer.train()

def test_torch_compile():
    tiny_llama = LlamaForCausalLM(config).to(device)
    tiny_llama = torch.compile(tiny_llama, mode="max-autotune")

    input_ids = train_dataset[0]["input_ids"].unsqueeze(0).to(device)
    tiny_llama(input_ids=input_ids)  # the first forward pass triggers compilation

torch._logging.set_logs(dynamo=logging.INFO)  # enable TorchDynamo logging

# Run one test at a time and compare the dynamo logs:
# test_torch_compile()
# test_torch_compile_hf_trainer()
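
The same logging can also be enabled without modifying the script via the TORCH_LOGS environment variable, which torch reads at import time (my understanding is that the unprefixed component name maps to INFO level):

import os
os.environ["TORCH_LOGS"] = "dynamo"  # must be set before torch is imported
import torch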

When running test_torch_compile(), there are many log lines showing TorchDynamo's compilation process, and at the end a summary like this:

I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] Function                                  Runtimes (s)
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] --------------------------------------  --------------
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] _compile.compile_inner                         69.1241
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] OutputGraph.call_user_compiler                 68.5498
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] create_aot_dispatcher_function                 68.5111
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] compile_fx.<locals>.fw_compiler_base           67.0193
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] compile_fx_inner                               67.019
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] GraphLowering.run                              44.2316
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] GraphLowering.compile_to_module                18.0673
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] Scheduler.__init__                             14.038
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] CachingAutotuner.benchmark_all_configs          1.3724
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] Scheduler.codegen                               0.306
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] WrapperCodeGen.generate                         0.0055
I1108 10:10:27.721000 1968843 site-packages/torch/_dynamo/utils.py:399] cudagraphify                                    0.0001

When running test_torch_compile_hf_trainer(), however, there are no TorchDynamo logs at all, and the final summary is also empty:

I1108 10:08:54.393000 1968523 site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1108 10:08:54.393000 1968523 site-packages/torch/_dynamo/utils.py:399] Function    Runtimes (s)
I1108 10:08:54.393000 1968523 site-packages/torch/_dynamo/utils.py:399] ----------  --------------

This indicates that the model is not being compiled at all.
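
As an extra sanity check, torch._dynamo keeps compilation counters that should be non-empty after a successful compile; a minimal sketch (assuming PyTorch 2.x):

from torch._dynamo.utils import counters

# A non-empty "frames" counter after training means TorchDynamo actually
# compiled at least one frame; it stays empty when nothing was compiled.
print(dict(counters["frames"]))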

Expected behavior

Setting torch_compile=True in TrainingArguments should make Trainer compile the model properly.

singularity-s0 commented 2 weeks ago

This seems to be related to multiple GPUs: the issue does not occur when only one GPU is visible via CUDA_VISIBLE_DEVICES.
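
For anyone reproducing this, the single-GPU case can be forced at the top of the script, before CUDA is initialized (a minimal sketch; the device index is arbitrary):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before the first CUDA call
import torch
assert torch.cuda.device_count() == 1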

LysandreJik commented 1 week ago

This seems like a potential issue with the logs rather than with torch.compile itself.

cc @MekkCyber, if you have the bandwidth, could you take a look at this?

MekkCyber commented 4 days ago

Hi @singularity-s0, in a multi-GPU setting you need to launch the script with accelerate launch script.py.
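
If you'd rather stay in plain Python than use the CLI, Accelerate's notebook_launcher can spawn the same worker processes; a minimal sketch, assuming two visible GPUs:

from accelerate import notebook_launcher

# Spawns num_processes workers, mirroring what accelerate launch does.
notebook_launcher(test_torch_compile_hf_trainer, num_processes=2)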

singularity-s0 commented 3 days ago

OK, thanks. Would you be kind enough to point out where torch.compile is called in the code, so that I can better analyze the logic?
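
For reference, my working assumption is that torch_compile=True should behave roughly like compiling the model manually before passing it to Trainer; a sketch of that baseline, reusing config, train_dataset, and device from the script above (the output directory name is arbitrary, and the actual call site inside Trainer/Accelerate may differ across versions):

# Manual-compile baseline, assuming the default "inductor" backend.
model = torch.compile(LlamaForCausalLM(config).to(device), backend="inductor")
args = TrainingArguments("manual-compile-out", per_device_train_batch_size=2, max_steps=1)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()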