huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`dataloader_persistent_workers=True` causes fork-bomb due to repeated creation of `eval_dataloader` #28469

Open naba89 opened 8 months ago

naba89 commented 8 months ago

Who can help?

@muellerzr @pacman100

Reproduction

import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import BaseModelOutput

# Dummy Dataset
class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.rand(size, 10)  # Random data
        self.labels = torch.randint(0, 2, (size,))  # Binary labels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {'input_ids': self.data[idx], 'labels': self.labels[idx]}

@dataclass
class DummyModelOutput(BaseModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None

# Dummy Model
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, input_ids, labels=None) -> DummyModelOutput:
        outputs = self.linear(input_ids)
        # Compute the loss only when labels are provided
        loss = F.cross_entropy(outputs, labels) if labels is not None else None
        return DummyModelOutput(loss=loss, logits=outputs)

if __name__ == '__main__':

    # using wandb, because it logs system metrics periodically
    os.environ["WANDB_PROJECT"] = "dummy_project"

    # Create dataset and model instances
    dataset = DummyDataset(size=1000)
    model = DummyModel()

    persistent_workers = False    # set to True to enable persistent workers

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./test_trainer",
        run_name=f'dataloader_persistent_workers={persistent_workers}',
        num_train_epochs=20,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        dataloader_num_workers=8,
        dataloader_persistent_workers=persistent_workers,
        logging_strategy="no",
        evaluation_strategy="epoch",
    )

    # Initialize the custom trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    # Train the model
    trainer.train()

Expected behavior

Since `get_eval_dataloader` is called on every `evaluate` call, with `dataloader_persistent_workers=True` the worker processes of the previous eval dataloaders are never shut down. New workers pile up on every evaluation, leading to a fork bomb that exhausts system resources and causes instability or a crash.
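A minimal pure-PyTorch sketch of this mechanism, assuming Linux (the default `fork` start method) and assuming, as suspected later in this thread, that something keeps a reference to every previously created loader (for example the accelerator that prepared it). `make_eval_loader` and `count_leaked_workers` are hypothetical names, not Trainer APIs. Each new `DataLoader` with `persistent_workers=True` starts its own worker pool on its first iteration, and those workers stay alive for as long as the loader is referenced:

```python
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_eval_loader(dataset):
    # Hypothetical stand-in for a get_eval_dataloader-style factory:
    # it builds a brand-new DataLoader on every call.
    return DataLoader(dataset, batch_size=16, num_workers=2, persistent_workers=True)


def count_leaked_workers(num_evals):
    baseline = len(multiprocessing.active_children())
    dataset = TensorDataset(torch.rand(64, 10))
    # Holding every loader mimics an object that keeps references to all
    # dataloaders it has seen, so none of them is garbage-collected.
    kept_alive = []
    for _ in range(num_evals):
        loader = make_eval_loader(dataset)
        for _batch in loader:  # first pass starts 2 persistent worker processes
            pass
        kept_alive.append(loader)
    # Workers of every referenced loader are still running at this point.
    return len(multiprocessing.active_children()) - baseline


if __name__ == "__main__":
    print(count_leaked_workers(3))  # 3 evaluations x 2 workers each
```

Once the references are dropped (here, when the function returns), PyTorch shuts the workers down; the leak only happens while the old loaders are still reachable.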

This is visible in the plots below, generated with the reproduction script (wandb system metrics section):

[Image: wandb system metrics plots from the reproduction script]

Having the persistent dataloader option is good, but the eval dataloader logic needs fixing: the eval dataloader should be created once and reused, since the eval datasets do not change in the middle of training.
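The create-once-and-reuse idea can be checked by counting worker processes. This sketch uses a hypothetical `EvalLoaderCache` helper (not part of transformers), again assumes Linux and an eval dataset that does not change between evaluations, and shows that however many evaluations run, only one pool of persistent workers exists:

```python
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


class EvalLoaderCache:
    """Hypothetical helper: builds the eval DataLoader once, then reuses it."""

    def __init__(self, dataset, num_workers=2):
        self.dataset = dataset
        self.num_workers = num_workers
        self._loader = None

    def get(self):
        # Create the loader lazily on first use; later calls return the same
        # object, so its persistent workers are reused instead of new ones forked.
        if self._loader is None:
            self._loader = DataLoader(
                self.dataset,
                batch_size=16,
                num_workers=self.num_workers,
                persistent_workers=True,
            )
        return self._loader


def count_live_workers(num_evals):
    baseline = len(multiprocessing.active_children())
    cache = EvalLoaderCache(TensorDataset(torch.rand(64, 10)))
    for _ in range(num_evals):
        for _batch in cache.get():
            pass
    return len(multiprocessing.active_children()) - baseline


if __name__ == "__main__":
    print(count_live_workers(5))  # stays at num_workers, not 5 x num_workers
```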

This option was added in #27058 and #27189.

naba89 commented 7 months ago

gentle ping: @muellerzr @pacman100

amyeroberts commented 7 months ago

Gentle ping @muellerzr @pacman100

muellerzr commented 7 months ago

Thanks for your patience! This has now been fixed on main; you can use `pip install git+https://github.com/huggingface/transformers` until the next release.

Lakoc commented 6 months ago

@muellerzr I am still experiencing a fork bomb with v4.39.3. I was able to fix it locally by reusing the accelerator-prepared version of the dataloader: `self._eval_dataloader = self.accelerator.prepare(eval_dataloader)`. Does this have any undesirable side effects? Is it possible that the accelerator internally creates new processes with DDP?
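A minimal sketch of this workaround as a Trainer subclass that only overrides the public `get_eval_dataloader` method instead of touching private attributes. `ReusingEvalTrainer` is a hypothetical name and not the upstream fix; it assumes that `Trainer.get_eval_dataloader` returns an accelerator-prepared loader (as in the versions discussed here) and that the evaluation dataset does not change during training:

```python
from transformers import Trainer


class ReusingEvalTrainer(Trainer):
    """Hypothetical Trainer subclass that prepares each eval dataloader only once.

    Assumes the evaluation dataset(s) do not change during training.
    """

    def get_eval_dataloader(self, eval_dataset=None):
        # Only cache the default eval set (None) or a named one (a str, accepted
        # by newer releases); an explicitly passed dataset is built fresh each time.
        cacheable = eval_dataset is None or isinstance(eval_dataset, str)
        if not cacheable:
            return super().get_eval_dataloader(eval_dataset)

        cache = getattr(self, "_prepared_eval_dataloaders", None)
        if cache is None:
            cache = self._prepared_eval_dataloaders = {}
        if eval_dataset not in cache:
            # Assumption: super() builds the DataLoader and runs it through
            # self.accelerator.prepare, so the prepared loader is what gets cached.
            cache[eval_dataset] = super().get_eval_dataloader(eval_dataset)
        return cache[eval_dataset]
```

It is used exactly like `Trainer`, e.g. `ReusingEvalTrainer(model=model, args=training_args, train_dataset=dataset, eval_dataset=dataset)`. Caching the prepared object means repeated evaluations iterate the same loader, so its persistent workers are reused rather than re-forked.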

muellerzr commented 6 months ago

The fix only applies if you set pin_memory=True (it will then reuse the dataloader).

I'll think about undesirable side effects, and if I can't come up with any, we may switch to always reusing it.

Lakoc commented 6 months ago

Thanks for the quick response; however, I am using pin_memory=True and still experiencing this: https://wandb.ai/butspeechfit/czech_ssl/reports/Untitled-Report--Vmlldzo3NDI5ODg1?accessToken=ygwuj69duarvjx0x2zzl81j7ru7h1v4annwjawumkou3fjydqq9a2vwaresdscfv

muellerzr commented 6 months ago

We can definitely go ahead and reuse the accelerate-prepared dataloader then (we probably should for eval even without pin_memory, tbh). I'll get a PR up shortly.

muellerzr commented 6 months ago

@Lakoc any chance you have a reproducer I can use to test with? :)

Lakoc commented 6 months ago

I am sorry, but we have a custom wrapper on top of transformers and cannot share it right now. I can create a simple example tomorrow if that helps.

muellerzr commented 6 months ago

That would be great, since this solution fixed the case reproduced here.

Lakoc commented 6 months ago

I reproduced it with the example above, just slightly enlarging the dataset and setting more epochs, and I observe the number of processes growing: https://api.wandb.ai/links/alexander-polok/nryqo6fa

import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import BaseModelOutput

# Dummy Dataset
class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.rand(size, 10)  # Random data
        self.labels = torch.randint(0, 2, (size,))  # Binary labels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {'input_ids': self.data[idx], 'labels': self.labels[idx]}

@dataclass
class DummyModelOutput(BaseModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None

# Dummy Model
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, input_ids, labels=None) -> DummyModelOutput:
        outputs = self.linear(input_ids)
        # Compute the loss only when labels are provided
        loss = F.cross_entropy(outputs, labels) if labels is not None else None
        return DummyModelOutput(loss=loss, logits=outputs)

if __name__ == '__main__':
    # using wandb, because it logs system metrics periodically
    os.environ["WANDB_PROJECT"] = "dummy_project"

    # Create dataset and model instances
    dataset = DummyDataset(size=8000)
    model = DummyModel()

    persistent_workers = True  # set to False to disable persistent workers

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./test_trainer",
        run_name=f'dataloader_persistent_workers={persistent_workers}_bigger',
        num_train_epochs=200,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        dataloader_num_workers=8,
        dataloader_persistent_workers=persistent_workers,
        logging_strategy="no",
        evaluation_strategy="epoch",
    )

    # Initialize the custom trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    # Train the model
    trainer.train()

Versions: transformers 4.39.3, torch 2.1.0, accelerate 0.26.1

amyeroberts commented 3 months ago

cc @muellerzr @SunMarc

amyeroberts commented 1 month ago

Gentle ping @muellerzr. Although I think this might have been resolved in #29538?

tasansal commented 1 month ago

I am still having this issue with transformers==4.43.3. Every time eval runs, it creates a new instance of the dataloaders and memory keeps climbing. This does not happen for the train dataloaders.

I use DDP on a single node with 8 GPUs, 2 dataloader workers, persistent_workers=True, and pin_memory=True.