huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Using StatefulDataLoader when num_workers>1 doesn't work #3080

Closed: kmchiti closed this issue 1 month ago

kmchiti commented 2 months ago

System Info

Accelerate version: 0.34.0
Platform: Linux 5.15.0-101-generic #111-Ubuntu SMP x86_64 GNU/Linux
Python version: 3.10.14
PyTorch version (GPU?): 2.1.2+cu118 True
System RAM: 494 GB
GPU type: 4x NVIDIA A100

Reproduction

  1. Create an IterableDataset.
  2. Prepare a toy model (e.g., GPT-2 with a limited vocabulary size and sequence length).
  3. Use a DataLoader with use_stateful_dataloader=True in Accelerate.
  4. Run the training script on a single GPU (works fine).
  5. Run the training script on multiple GPUs using torchrun --nproc_per_node=4 (error occurs).

Code Example:

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from datasets import Dataset
from accelerate import Accelerator
from torch.utils.data import DataLoader
from accelerate.utils import DataLoaderConfiguration
import torch

if __name__ == "__main__":

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Toy iterable dataset of random token ids, sharded so several workers can read it
    ds = Dataset.from_dict({"input_ids": torch.randint(0, 10, (4096, 64))}).to_iterable_dataset(num_shards=5)

    dataloader_params = {
        "batch_size": 16,
        "collate_fn": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        "num_workers": 4,  # the error only shows up with worker processes
        "pin_memory": True,
    }
    data_loader = DataLoader(ds, **dataloader_params)

    # Tiny GPT-2 so the script runs quickly
    model = GPT2LMHeadModel(GPT2Config(vocab_size=10, n_positions=64))
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # use_stateful_dataloader=True is the feature under test
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
    model, data_loader, optimizer = accelerator.prepare(model, data_loader, optimizer)

    for step, batch in enumerate(data_loader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

Error:

When running the script on multiple GPUs, I encountered the following error:

Traceback (most recent call last):
  File "/path/to/your/script/dataloader_test.py", line 31, in <module>
    for step, batch in enumerate(data_loader):
  File "/path/to/accelerate/data_loader.py", line 796, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
  File "/path/to/accelerate/data_loader.py", line 749, in _fetch_batches
    self._update_state_dict()
  File "/path/to/accelerate/data_loader.py", line 486, in _update_state_dict
    self.adjust_state_dict_for_prefetch()
  File "/path/to/accelerate/data_loader.py", line 466, in adjust_state_dict_for_prefetch
    if self.dl_state_dict["_sampler_iter_yielded"] > 0:
KeyError: '_sampler_iter_yielded'

Expected behavior

The script should run without errors on multiple GPUs, matching the behavior observed on a single GPU without multiprocessing. The issue seems to come from how _sampler_iter_yielded is read from dl_state_dict: when the DataLoader uses worker processes, dl_state_dict stores _sampler_iter_yielded under _main_snapshot rather than as a top-level key, so adjust_state_dict_for_prefetch should navigate the nested dictionary structure to reach this value.
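
For reference, here is a minimal sketch (outside of Accelerate, assuming torchdata's StatefulDataLoader is installed) that makes the layout difference visible by printing the top-level state dict keys with and without worker processes:

import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.tensor(idx)

if __name__ == "__main__":
    for num_workers in (0, 2):
        dl = StatefulDataLoader(ToyDataset(), batch_size=8, num_workers=num_workers)
        it = iter(dl)
        next(it)  # pull one batch so the state dict reflects in-progress iteration
        state = dl.state_dict()
        print(f"num_workers={num_workers}: top-level keys = {sorted(state.keys())}")
        # With num_workers=0 the sampler counters (e.g. _sampler_iter_yielded) are
        # expected at the top level; with workers they appear to sit in a nested
        # snapshot, which is what the KeyError above runs into.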

muellerzr commented 2 months ago

To clarify the problem a bit more: the issue is that we currently don't support num_workers > 1.

(There are also some issues with datasets.Dataset in general.)
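
In the meantime, the quickest way to unblock yourself is to keep the data loading in the main process, e.g. something like:

    dataloader_params = {
        "batch_size": 16,
        "collate_fn": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        "num_workers": 0,  # avoids the multi-worker state dict layout for now
        "pin_memory": True,
    }
    data_loader = DataLoader(ds, **dataloader_params)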

muellerzr commented 2 months ago

For right now this isn't supported/implemented yet.

muellerzr commented 2 months ago

We'll need a different, custom adjust_state_dict_for_prefetch function. Overriding it is something we do support, so feel free to play around with it in the interim.
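
For reference, a rough, untested sketch of the kind of override I mean (the nested _main_snapshot layout is taken from your report, and monkey-patching the prepared dataloader is just one hypothetical way to wire it in; double-check the key names against your own state_dict() output):

import types
from accelerate.state import PartialState

def adjust_state_dict_for_prefetch(self):
    # Hypothetical replacement: handle both the flat (num_workers=0) layout and
    # the nested multi-worker layout of the StatefulDataLoader state dict.
    state = self.dl_state_dict
    if "_sampler_iter_yielded" not in state and "_main_snapshot" in state:
        # Multi-worker layout reported above; verify against your own state_dict().
        state = state["_main_snapshot"]
    # Roll back the batches prefetched for the other processes; the exact
    # adjustment should mirror whatever the stock method does.
    factor = PartialState().num_processes - PartialState().process_index - 1
    if state.get("_sampler_iter_yielded", 0) > 0:
        state["_sampler_iter_yielded"] -= factor

# One hypothetical way to plug it in: replace the bound method on the
# dataloader returned by accelerator.prepare(...).
data_loader.adjust_state_dict_for_prefetch = types.MethodType(
    adjust_state_dict_for_prefetch, data_loader
)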

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.