huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, with automatic mixed precision (including fp8) and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Add early support for `torchdata.stateful_dataloader.StatefulDataLoader` within the `Accelerator` #2895

Open byi8220 opened 6 days ago

byi8220 commented 6 days ago

What does this PR do?

Fixes https://github.com/huggingface/accelerate/issues/2859

This PR does the following:

  1. Add a new field `use_stateful_dataloader` to `DataLoaderConfiguration`. When this is enabled, all dataloaders prepared and returned by the `Accelerator` are `StatefulDataLoader` objects from the `torchdata` library (see the usage sketch after this list).
  2. Create a `DataLoaderAdapter` class that can wrap and behave as either PyTorch's `DataLoader` or a variant of it such as `StatefulDataLoader`.
  3. Refactor `DataLoaderShard`, `DataLoaderDispatcher`, and `SkipDataLoader` to inherit from `DataLoaderAdapter` instead of `DataLoader`.
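
A minimal usage sketch of the new flag, assuming the API proposed in this PR (the option is not in a released version yet, and names or defaults may still change before merge):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Opt in to torchdata's StatefulDataLoader for every dataloader the Accelerator prepares.
# `use_stateful_dataloader` is the field added by this PR.
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

dataset = TensorDataset(torch.arange(128).float())
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

# The prepared object behaves like a StatefulDataLoader, so iteration progress
# can be captured with state_dict() and later restored with load_state_dict().
for step, batch in enumerate(dataloader):
    if step == 3:
        resume_state = dataloader.state_dict()
        break
```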

Testing

Added new unit tests verifying that `StatefulDataLoader` can be dropped in and that its iteration state can be saved and loaded; a sketch of the round-trip pattern follows.
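
For reference, the shape of such a round-trip test might look like the sketch below; the test name and exact assertions are illustrative, not the PR's actual test code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration


def test_stateful_dataloader_round_trip():
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
    )
    dataset = TensorDataset(torch.arange(64).float())
    dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

    # Consume a few batches, then snapshot the iteration state and note the next batch.
    iterator = iter(dataloader)
    for _ in range(3):
        next(iterator)
    state = dataloader.state_dict()
    (expected,) = next(iterator)

    # A freshly prepared dataloader restored from that state should resume at the same batch.
    resumed = accelerator.prepare(DataLoader(dataset, batch_size=4))
    resumed.load_state_dict(state)
    (actual,) = next(iter(resumed))
    assert torch.equal(actual, expected)
```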

Caveats

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@muellerzr

HuggingFaceDocBuilderDev commented 15 minutes ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.