huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Using Trainer + a pretrained tokenizer + 4D attention mask is extremely slow #32101

Open JackCai1206 opened 3 months ago

JackCai1206 commented 3 months ago

System Info

transformers 4.41.0

Who can help?

@ArthurZucker

Reproduction

from transformers import LlamaForCausalLM, LlamaConfig, TrainingArguments, Trainer, AutoTokenizer
from datasets import IterableDataset
import numpy as np

model_config = LlamaConfig(
    vocab_size=10,
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1024,
    max_position_embeddings=1024,
)
model = LlamaForCausalLM(model_config)
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

def get_data1():
    for i in range(10000):
        yield {'input_ids': np.zeros(1024, dtype=int), 'labels': np.zeros(1024, dtype=int), 'attention_mask': np.zeros((1, 1024, 1024), dtype=float)}

def get_data2():
    for i in range(10000):
        yield {'input_ids': np.zeros(1024, dtype=int), 'labels': np.zeros(1024, dtype=int), 'attention_mask': np.zeros((1024), dtype=int)}

ds_slow = IterableDataset.from_generator(get_data1).with_format('torch')
ds_fast = IterableDataset.from_generator(get_data2).with_format('torch')

# report_to='none' (a string) disables logging integrations; report_to=None falls back to 'all'
training_args = TrainingArguments(max_steps=1, output_dir='./out', report_to='none', per_device_train_batch_size=32, gradient_accumulation_steps=32)
trainer1 = Trainer(model, training_args, train_dataset=ds_slow, tokenizer=tokenizer)
trainer2 = Trainer(model, training_args, train_dataset=ds_fast, tokenizer=tokenizer)

import cProfile
cProfile.run('trainer1.train()', './test_slow.profile')
cProfile.run('trainer2.train()', './test_fast.profile')
import pstats

# compare the two profiles
p1 = pstats.Stats('./test_slow.profile')
p2 = pstats.Stats('./test_fast.profile')
p1.sort_stats('cumtime').print_stats()
         1582200 function calls (1401111 primitive calls) in 340.112 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  340.112  340.112 {built-in method builtins.exec}
        1    0.000    0.000  340.112  340.112 <string>:1(<module>)
        1    0.000    0.000  340.112  340.112 trainer.py:1784(train)
        1    0.017    0.017  340.112  340.112 trainer.py:1892(_inner_training_loop)
       33    0.001    0.000  326.171    9.884 data_loader.py:663(__iter__)
       33    0.001    0.000  325.473    9.863 data_loader.py:618(_fetch_batches)
 2486/265    0.001    0.000  325.428    1.228 {built-in method builtins.next}
       33    0.001    0.000  325.088    9.851 dataloader.py:625(__next__)
       33    0.725    0.022  325.083    9.851 dataloader.py:672(_next_data)
       33    0.002    0.000  323.988    9.818 fetch.py:24(fetch)
       33    0.000    0.000  320.979    9.727 trainer_utils.py:807(__call__)
       33    0.000    0.000  320.971    9.726 data_collator.py:270(__call__)
       33   16.982    0.515  320.971    9.726 data_collator.py:52(pad_without_fast_tokenizer_warning)
       33    0.005    0.000  303.989    9.212 tokenization_utils_base.py:3209(pad)
     6493  235.747    0.036  235.747    0.036 {built-in method torch.tensor}
      197    0.001    0.000  234.735    1.192 tokenization_utils_base.py:204(__init__)
      197    0.001    0.000  234.732    1.192 tokenization_utils_base.py:681(convert_to_tensors)
       99    0.000    0.000  234.730    2.371 tokenization_utils_base.py:718(as_tensor)
p2.sort_stats('cumtime').print_stats()
        1567440 function calls (1386340 primitive calls) in 16.431 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   16.431   16.431 {built-in method builtins.exec}
        1    0.000    0.000   16.431   16.431 <string>:1(<module>)
        1    0.000    0.000   16.431   16.431 trainer.py:1784(train)
        1    0.018    0.018   16.431   16.431 trainer.py:1892(_inner_training_loop)
       32    0.003    0.000   14.327    0.448 trainer.py:3212(training_step)
       32    0.001    0.000    8.830    0.276 accelerator.py:2093(backward)
       32    0.000    0.000    8.829    0.276 _tensor.py:433(backward)
       32    0.000    0.000    8.829    0.276 __init__.py:149(backward)
       32    8.827    0.276    8.827    0.276 {method 'run_backward' of 'torch._C._EngineBase' objects}
       33    0.000    0.000    4.546    0.138 memory.py:147(empty_cache)
       33    4.546    0.138    4.546    0.138 {built-in method torch._C._cuda_emptyCache}
 2486/265    0.001    0.000    1.469    0.006 {built-in method builtins.next}
       33    0.001    0.000    1.160    0.035 data_loader.py:663(__iter__)
       33    0.000    0.000    1.145    0.035 data_loader.py:618(_fetch_batches)
       33    0.000    0.000    1.136    0.034 dataloader.py:625(__next__)
       33    0.003    0.000    1.134    0.034 dataloader.py:672(_next_data)
       33    0.002    0.000    1.124    0.034 fetch.py:24(fetch)
       32    0.000    0.000    0.955    0.030 trainer.py:3254(compute_loss)
...
        1    0.000    0.000    0.000    0.000 modeling_utils.py:903(_
...

Expected behavior

The profiler trace is very long, so I have included only the first few lines of each. I am training a small Llama model on dummy data; the only difference between the two datasets is that the slow one yields 4D attention masks, a feature recently added in #27539. Both trainers run for a single optimizer step (max_steps=1, i.e. 32 accumulated micro-batches).
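For anyone reproducing this, pstats can do that truncation itself. A minimal standard-library sketch, where work() is a hypothetical stand-in for trainer1.train(): after sorting, print_stats(10) keeps only the ten most expensive entries, and a string argument such as print_stats("tokenization_utils_base", 10) first filters entries by a regular expression.

```python
import cProfile
import io
import pstats


def work():
    # Hypothetical stand-in workload; in the issue this would be trainer1.train().
    return sum(i * i for i in range(200_000))


profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()

# Write the report to a string buffer instead of stdout so it can be inspected.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
# Sort by cumulative time and print only the 10 most expensive entries.
stats.sort_stats("cumulative").print_stats(10)
report = buffer.getvalue()
print(report)
```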

As the profiles show, the slow run takes 340 s while the fast one takes 16 s.
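A quick back-of-the-envelope check, using only the numbers printed in the two profiles and the settings from the reproduction script. Only element counts are computed, since the exact dtype after with_format('torch') is not shown in the issue.

```python
# Figures copied from the two cProfile reports above.
slow_total = 340.112   # trainer1.train(), 4D float masks
fast_total = 16.431    # trainer2.train(), 1D int masks
slow_tensor = 235.747  # tottime of {built-in method torch.tensor} in the slow run
slow_loader = 326.171  # cumtime of data_loader.py:663(__iter__) in the slow run

print(f"slowdown: {slow_total / fast_total:.1f}x")                 # 20.7x
print(f"share in torch.tensor: {slow_tensor / slow_total:.0%}")    # 69%
print(f"share in data loading: {slow_loader / slow_total:.0%}")    # 96%

# Size of the masks the collator has to handle (settings from the reproduction).
seq_len = 1024
per_device_batch = 32
accum_steps = 32
elements_per_mask = 1 * seq_len * seq_len  # one mask of shape (1, 1024, 1024)
elements_per_batch = per_device_batch * elements_per_mask
elements_per_step = accum_steps * elements_per_batch
print(f"{elements_per_batch:,} mask elements per micro-batch")     # 33,554,432
print(f"{elements_per_step:,} mask elements per optimizer step")   # 1,073,741,824
```

In other words, roughly a billion mask values go through the collator for one optimizer step, and most of the slow run's time is spent turning them into tensors.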

The trainer with 4D masks is therefore about 20 times slower. The problem probably lies in the default collator, DataCollatorWithPadding, which Trainer uses when a pretrained tokenizer is passed and which calls tokenizer.pad on the 4D attention masks: in the slow profile, about 236 s of the 340 s is spent in torch.tensor, almost all of it under convert_to_tensors, which tokenizer.pad calls. Removing either 1) the pretrained tokenizer or 2) the 4D attention mask makes the trainer run much faster.
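One plausible explanation, not verified line by line against the 4.41 source, is that tokenizer.pad converts tensor inputs to nested Python lists and later rebuilds them with torch.tensor, so every mask element passes through the Python interpreter. The sketch below simulates that round trip with tolist() (an assumption about what pad does internally) and compares it with torch.stack on the already equal-length tensors, using smaller masks so it finishes quickly.

```python
import time

import torch

# Smaller than the issue's (1, 1024, 1024) masks so the example runs quickly.
batch_size, seq_len = 8, 256
masks = [torch.zeros(1, seq_len, seq_len) for _ in range(batch_size)]

# Route 1: tensor -> nested Python lists -> torch.tensor.
# (Assumed to mimic the round trip inside tokenizer.pad.)
start = time.perf_counter()
via_lists = torch.tensor([m.tolist() for m in masks])
list_seconds = time.perf_counter() - start

# Route 2: stack the tensors directly; no per-element Python objects are created.
start = time.perf_counter()
via_stack = torch.stack(masks)
stack_seconds = time.perf_counter() - start

# Both routes produce the same (batch, 1, seq_len, seq_len) float32 tensor.
assert torch.equal(via_lists, via_stack)
print(f"lists + torch.tensor: {list_seconds:.4f} s")
print(f"torch.stack:          {stack_seconds:.4f} s")
```

The absolute timings depend on the machine; the point is that the list route scales with the number of Python objects created, which for 4D masks grows with the square of the sequence length.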

ArthurZucker commented 3 months ago

Interesting! Would you like to open a PR? (Maybe torch.nn.functional.pad would work better?)
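As an illustration of the torch-based padding suggested above, here is a minimal sketch built on torch.nn.functional.pad. The helper pad_4d_masks is hypothetical, it right-pads (an assumption; a tokenizer configured for left padding would need the offsets swapped), and the pad value is a parameter because whether 0 or a large negative number means "masked" depends on the mask convention the model expects.

```python
import torch
import torch.nn.functional as F


def pad_4d_masks(masks, pad_value=0.0):
    """Right-pad a list of (1, q_len, kv_len) masks to a common size and stack them.

    Hypothetical helper for illustration. pad_value=0.0 assumes a binary mask
    where 0 means "do not attend"; use a different value for other conventions.
    """
    max_q = max(m.shape[-2] for m in masks)
    max_kv = max(m.shape[-1] for m in masks)
    padded = [
        # F.pad lists the last dimension first: (left, right) for the key axis,
        # then (top, bottom) for the query axis; here both pad only at the end.
        F.pad(m, (0, max_kv - m.shape[-1], 0, max_q - m.shape[-2]), value=pad_value)
        for m in masks
    ]
    return torch.stack(padded)  # shape (batch, 1, max_q, max_kv)


a = torch.tril(torch.ones(1, 3, 3))  # causal mask for 3 tokens
b = torch.tril(torch.ones(1, 5, 5))  # causal mask for 5 tokens
batch = pad_4d_masks([a, b])
print(batch.shape)  # torch.Size([2, 1, 5, 5])
```

Because everything stays in tensor operations, the cost no longer depends on creating one Python object per mask element.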

jpcorb20 commented 1 month ago

Hello @JackCai1206, have you tried using a custom collator and passing it to the Trainer through the data_collator parameter? I also have an issue with custom 4D masks during training (with a custom collator), but mine is related to OOM...
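A minimal sketch of such a collator, assuming (as in the reproduction) that every example already has the same length, so nothing needs padding and each field can simply be stacked. The name collate_with_4d_masks is hypothetical.

```python
import torch


def collate_with_4d_masks(features):
    """Hypothetical collator for fixed-length examples that already carry 4D masks.

    Assumes every example has the same sequence length, so each field is
    stacked along a new batch axis without any padding or list conversion.
    """
    batch = {}
    for key in features[0]:
        batch[key] = torch.stack([torch.as_tensor(f[key]) for f in features])
    return batch


# Dummy features shaped like the reproduction's get_data1(), but shorter.
seq_len = 16
features = [
    {
        "input_ids": torch.zeros(seq_len, dtype=torch.long),
        "labels": torch.zeros(seq_len, dtype=torch.long),
        "attention_mask": torch.zeros(1, seq_len, seq_len),
    }
    for _ in range(4)
]
batch = collate_with_4d_masks(features)
print({k: tuple(v.shape) for k, v in batch.items()})
# {'input_ids': (4, 16), 'labels': (4, 16), 'attention_mask': (4, 1, 16, 16)}
```

It would be passed as Trainer(model, training_args, train_dataset=ds_slow, tokenizer=tokenizer, data_collator=collate_with_4d_masks); when data_collator is given, Trainer does not fall back to DataCollatorWithPadding. For fixed-length tensor features like these, transformers' own default_data_collator, which stacks tensor fields, may also be enough, though that has not been tested here.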

csking101 commented 4 days ago

Hi @ArthurZucker, I am interested in working on this issue, could I take it up?

ArthurZucker commented 1 day ago

Hey! Sure, opening a PR is the way to go 🤗