huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer may stop short of requested number of epochs when using gradient_accumulation_steps > 1 #33455

Open tomtseng opened 1 week ago

tomtseng commented 1 week ago

System Info

Who can help?

@muellerzr @SunMarc (Trainer)

Information

Tasks

Reproduction

Run the following code. It trains for 3 epochs with a batch size of 2, gradient accumulation steps of 2, and a training dataset of size 9 (larger sizes such as 101 also reproduce the issue, so this is not just an edge case with tiny datasets).

from typing_extensions import override

from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
from datasets import Dataset

# Trainer subclass that counts every example passed to compute_loss.
class MyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.examples_seen = 0

    @override
    def compute_loss(self, model, inputs, return_outputs=False):
        self.examples_seen += inputs["input_ids"].shape[0]
        print(
            f"batch size={inputs['input_ids'].shape[0]},"
            f" examples seen={self.examples_seen}"
        )
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

# Tiny dataset of single-token examples; 9 examples with batch size 2 gives
# 5 batches per epoch, the last of which holds only 1 example.
DATASET_SIZE = 9
dataset = Dataset.from_dict({
    "input_ids": [[0] for _ in range(DATASET_SIZE)],
    "labels": [0] * DATASET_SIZE,
})

model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m")
model.config.pad_token_id = model.config.eos_token_id

trainer = MyTrainer(
    model=model,
    args=TrainingArguments(
        output_dir="/tmp/results",
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        save_strategy="no",
    ),
    train_dataset=dataset,
)

trainer.train()
print("Examples seen:", trainer.examples_seen)
print("Epochs x training set size:", trainer.args.num_train_epochs * DATASET_SIZE)

The final print-out is:

{'train_runtime': 2.0071, <... entries omitted for brevity ...>, 'epoch': 2.4}
Examples seen: 22
Epochs x training set size: 27

Here we see that the trainer actually did only 2.4 epochs rather than the 3 specified in TrainingArguments, and saw 22 examples instead of the expected 27.

In contrast, if I double per_device_train_batch_size to 4 and halve gradient_accumulation_steps to 1, which keeps the same effective batch size, the trainer runs the full 3 epochs as I expect.
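Working through the numbers by hand, I suspect the shortfall comes from how the total number of optimizer steps is derived. My guess (based only on the observed behavior, not on the Trainer source) is that the updates per epoch are computed by flooring the number of dataloader batches divided by gradient_accumulation_steps, so a trailing partial accumulation window is dropped, and that accumulation windows then carry over epoch boundaries. A quick sketch under those assumptions reproduces both runs above; simulate and its variable names are just made up for illustration:

import math

def simulate(dataset_size, per_device_batch, grad_accum, num_epochs):
    # Dataloader batches per epoch (the last batch may be smaller).
    batches_per_epoch = math.ceil(dataset_size / per_device_batch)
    # Assumption: optimizer steps per epoch are floored, so a trailing
    # partial accumulation window does not count toward max_steps.
    update_steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    max_steps = math.ceil(num_epochs * update_steps_per_epoch)
    # Assumption: accumulation windows carry across epoch boundaries, so
    # training stops after consuming max_steps * grad_accum batches.
    batches_consumed = max_steps * grad_accum
    epochs_done = batches_consumed / batches_per_epoch
    return max_steps, epochs_done

print(simulate(9, 2, 2, 3))  # (6, 2.4) -> matches 'epoch': 2.4 above
print(simulate(9, 4, 1, 3))  # (9, 3.0) -> matches the full 3 epochs with batch size 4

Consuming 12 of the expected 15 batches would also line up with the 22 examples seen: two full epochs (18 examples) plus two more full batches of 2 (4 examples).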

Expected behavior

I expect the trainer to train for exactly 3 epochs rather than stopping at a partial epoch. In general, when the effective batch size is the same across two training runs, I expect them to be essentially identical up to numerical precision.

However, it's possible I am misunderstanding the expected behavior of gradient_accumulation_steps, in which case I would appreciate a pointer to what setting (if any) I can change to guarantee full epochs.

tomtseng commented 1 week ago

Hmm, I think this is a duplicate of #31677, but I hope the minimal reproduction is still helpful.