naba89 opened 8 months ago
gentle ping: @muellerzr @pacman100
Gentle ping @muellerzr @pacman100
Thanks for your patience! This has now been fixed on main; you can use pip install git+https://github.com/huggingface/transformers until the next release.
@muellerzr I am still experiencing a fork bomb with version v4.39.3. I was able to fix it locally by reusing the accelerator-prepared version of the dataloader: self._eval_dataloader = self.accelerator.prepare(eval_dataloader). Are there any undesirable side effects to this? Is it possible that, internally, the accelerator is creating new processes with DDP?
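For reference, a minimal sketch of that kind of local workaround, assuming a Trainer subclass; CachedEvalTrainer and _cached_eval_dataloader are illustrative names, not transformers API:

```python
# Minimal sketch of the workaround described above (illustrative names, not
# transformers API): build the eval dataloader once and reuse it so that
# persistent workers are not respawned on every evaluate() call.
from transformers import Trainer

class CachedEvalTrainer(Trainer):
    def get_eval_dataloader(self, eval_dataset=None):
        # Assumes the eval dataset does not change during training.
        if getattr(self, "_cached_eval_dataloader", None) is None:
            # super() already runs the dataloader through accelerator.prepare().
            self._cached_eval_dataloader = super().get_eval_dataloader(eval_dataset)
        return self._cached_eval_dataloader
```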
The fix only applies if you set pin_memory=True (in that case it reuses the dataloader).
I'll think on possible undesirable side effects, and if I can't come up with any we may switch it to always reuse it.
Thanks for the quick response; however, I am using pin_memory=True and still experiencing this: https://wandb.ai/butspeechfit/czech_ssl/reports/Untitled-Report--Vmlldzo3NDI5ODg1?accessToken=ygwuj69duarvjx0x2zzl81j7ru7h1v4annwjawumkou3fjydqq9a2vwaresdscfv
We can definitely go ahead and re-use the accelerate dataloader then (probably should even without pin_memory for eval, tbh). I'll get a PR up shortly.
@Lakoc any chance you have a reproducer I can use to test with? :)
I am sorry, but we have a custom wrapper on top of transformers, and I cannot share it right now. I can create a simple example tomorrow if that helps.
That'd be great, as the current solution fixed what was reproduced here.
I reproduced the example above, just slightly enlarging the dataset and setting more epochs, and I am observing the number of processes growing - https://api.wandb.ai/links/alexander-polok/nryqo6fa
```python
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import BaseModelOutput


# Dummy Dataset
class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.rand(size, 10)  # Random data
        self.labels = torch.randint(0, 2, (size,))  # Binary labels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {'input_ids': self.data[idx], 'labels': self.labels[idx]}


@dataclass
class DummyModelOutput(BaseModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None


# Dummy Model
class DummyModel(torch.nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, input_ids, labels=None) -> DummyModelOutput:
        outputs = self.linear(input_ids)
        loss = F.cross_entropy(outputs, labels)
        return DummyModelOutput(loss=loss, logits=outputs)


if __name__ == '__main__':
    # using wandb, because it logs system metrics periodically
    os.environ["WANDB_PROJECT"] = "dummy_project"

    # Create dataset and model instances
    dataset = DummyDataset(size=8000)
    model = DummyModel()

    persistent_workers = True  # set to True to enable persistent workers

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./test_trainer",
        run_name=f'dataloader_persistent_workers={persistent_workers}_bigger',
        num_train_epochs=200,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        dataloader_num_workers=8,
        dataloader_persistent_workers=persistent_workers,
        logging_strategy="no",
        evaluation_strategy="epoch",
    )

    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    # Train the model
    trainer.train()
```
transformers 4.39.3, torch 2.1.0, accelerate 0.26.1
cc @muellerzr @SunMarc
Gentle ping @muellerzr, although I think this might have been resolved in #29538?
I am still having this issue with transformers==4.43.3. Every time eval is run, it creates a new instance of the dataloaders, and memory keeps climbing. This does not happen for the train dataloaders.
I use DDP on a single node with 8 GPUs, 2 dataloader workers, persistent_workers=True, and pin_memory=True.
System Info

transformers version: 4.36.2

Who can help?

@muellerzr @pacman100

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction

Expected behavior
Since get_eval_dataloader is called on every evaluate call, with dataloader_persistent_workers=True the previous worker processes are never killed, which leads to a fork bomb that exhausts system resources and causes instability/crashes, as you can see in the plots below generated with the reproduction script (in the wandb system metrics section).

Having the persistent dataloader option is good. Still, it is necessary to fix the eval loader logic: create it once and reuse it, since the eval datasets won't change in the middle of training.

This option was added in #27058 and #27189
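Since the mechanism matters here, a small standalone sketch of it (not Trainer code; it assumes psutil is available, and the kept_alive list merely stands in for the long-lived references the training loop holds): each re-created DataLoader with persistent_workers=True keeps its worker processes alive, so the process count grows with every evaluation.

```python
# Standalone sketch of the leak mechanism (assumes psutil is installed; the
# kept_alive list stands in for long-lived references held by the training
# loop and is not part of any transformers/accelerate API).
import psutil
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == '__main__':
    dataset = TensorDataset(torch.rand(64, 10))
    parent = psutil.Process()
    kept_alive = []

    for eval_round in range(3):
        # Simulates the eval dataloader being rebuilt on every evaluate call.
        loader = DataLoader(dataset, batch_size=8, num_workers=2,
                            persistent_workers=True)
        kept_alive.append(loader)
        for _ in loader:  # iterating spawns the persistent worker processes
            pass
        n_children = len(parent.children(recursive=True))
        print(f"after eval round {eval_round}: {n_children} child processes")
```

Reusing a single dataloader (or dropping the old references so it can be garbage collected) lets the workers be shut down instead of accumulating.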