huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Improve `skip_first_batches` method to efficiently support `IterableDataset` and `StatefulDataLoader` #2859

Closed yzhangcs closed 1 month ago

yzhangcs commented 3 months ago

Hi all, thank you for developing this great project. Currently, `skip_first_batches` naively iterates through all batches until the specified number have been consumed, which can be extremely slow for very large datasets. The latest version of the datasets library now supports resumable iterable datasets, and torchdata's StatefulDataLoader allows training state to be saved and restored efficiently.

I'm wondering if there are any plans to leverage these new features in Accelerate to make skip_first_batches more efficient and compatible with the latest datasets capabilities? If not, are there plans to add support for this in the future? Efficiently skipping batches on huge datasets would significantly speed up resuming interrupted training runs. Let me know if you need any additional information or have thoughts on the best way to approach this.
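For reference, here is a minimal sketch (outside of Accelerate, assuming torchdata is installed) of the resumption pattern these APIs enable, i.e. saving and restoring the dataloader's own state rather than re-iterating through already-consumed batches:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Any dataset works here; a plain range stands in for a real one.
dataloader = StatefulDataLoader(range(10_000), batch_size=8)

for step, batch in enumerate(dataloader):
    if step == 100:
        # Save this alongside the model/optimizer checkpoint.
        dataloader_state = dataloader.state_dict()
        break

# On resume: rebuild the dataloader and restore its state, so iteration
# continues right after the checkpointed batch instead of replaying
# everything before it.
resumed = StatefulDataLoader(range(10_000), batch_size=8)
resumed.load_state_dict(dataloader_state)
for batch in resumed:
    break  # this is already the batch after the checkpoint
```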

Thanks for considering this suggestion!

muellerzr commented 3 months ago

You can ping me (@muellerzr) or @SunMarc on these things; Sylvain hasn't worked at HF for well over a year now :)

muellerzr commented 3 months ago

Yes, we are indeed actively looking into this!

byi8220 commented 3 months ago

Ran into something annoying while looking at this. Merely importing StatefulDataLoader (i.e. putting the line from torchdata.stateful_dataloader import StatefulDataLoader anywhere in the code) causes one of the unit tests, check_seedable_sampler, to fail.

I suspect it has something to do with torchdata overriding torch's BatchSampler in this code. This is supported by the fact that, if I add the import and some logging, SeedableRandomSampler.__iter__() appears to be called one fewer time than expected:

# We should see the epoch and seed sequence [(0, 42), (1, 43), (2, 44)] twice, but the first call with seed 42 is missing
# It looks like the first sample is being drawn without setting a seed

stdout: Shuffled central dataloader passing.
stdout: {'x': tensor([-1.3022,  0.1278], device='cuda:0'), 'y': tensor([0.3097, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400,  0.7505], device='cuda:0'), 'y': tensor([0.9978, 4.5075], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.0168], device='cuda:0'), 'y': tensor([3.6974, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.9406], device='cuda:0'), 'y': tensor([1.2889, 4.9939], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -1.9510], device='cuda:0'), 'y': tensor([ 2.2716, -0.8553], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: Resetting epoch and seed
stdout: Setting seed at epoch 0 42
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -0.0168], device='cuda:0'), 'y': tensor([0.9978, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-1.9510, -1.3022], device='cuda:0'), 'y': tensor([-0.8553,  0.3097], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.8530], device='cuda:0'), 'y': tensor([3.6974, 1.2889], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.3162], device='cuda:0'), 'y': tensor([4.9939, 2.2716], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: original_items:
stdout:  tensor([-1.3022,  0.1278, -1.0400,  0.7505,  0.3047, -0.0168, -0.8530,  0.9406,
stdout:         -0.3162, -1.9510, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
stdout: new_items:
stdout:  tensor([ 0.7505,  0.1278, -1.0400, -0.0168, -1.9510, -1.3022,  0.3047, -0.8530,
stdout:          0.9406, -0.3162, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')

At the moment I don't know why this happens, so I can't tell if this is some misconfiguration in my local workspace, a bug somewhere in the torchdata library itself, or just a sharp edge that could be worked around.
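If it helps narrow things down, a quick (purely diagnostic) way to test the monkey-patching suspicion would be to capture the relevant torch classes before the import and compare identities afterwards:

```python
# Purely a diagnostic sketch for the suspicion above; not part of the test suite.
import torch.utils.data as tud

batch_sampler_before = tud.BatchSampler
dataloader_before = tud.DataLoader

from torchdata.stateful_dataloader import StatefulDataLoader  # noqa: F401  (bare import)

# If either of these prints False, the import really does swap out torch
# classes globally, which would explain an untouched test changing behavior.
print(tud.BatchSampler is batch_sampler_before)
print(tud.DataLoader is dataloader_before)
```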

Assuming there aren't other traps, writing the rest of the feature doesn't feel like too much work, though the most immediate solution I can think of (that isn't a big refactor), creating subclasses such as StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin) and letting duck typing do the rest, feels kinda hacky imo.
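For what it's worth, that subclassing idea would look roughly like this (entirely untested, and the class name is just a placeholder):

```python
from torchdata.stateful_dataloader import StatefulDataLoader

from accelerate.data_loader import DataLoaderShard, DataLoaderStateMixin


class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin):
    """Rough sketch: inherit device placement / iteration from DataLoaderShard and
    rely on StatefulDataLoader for state_dict()/load_state_dict(); constructor
    plumbing and MRO quirks are left unhandled here."""
```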

yzhangcs commented 3 months ago

@byi8220 Hi, it seems that datasets does not support states with a buffer right now: https://github.com/huggingface/datasets/pull/6658#issuecomment-2184828495

byi8220 commented 3 months ago

Gah, this feature is getting more complicated every second. We're also at the mercy of how StatefulDataLoader is implemented (ran into a tricky problem here) 😞

Hi, it seems that datasets does not support states with a buffer right now

Thank you for mentioning it. Is it accurate to call this a related but separate issue? Please correct me if I'm wrong; my understanding of the problem scope is that:

  1. The datasets library is responsible for supporting state_dict/load_state_dict when a dataset has a buffer
  2. The accelerate library is responsible for utilizing state_dict/load_state_dict to save and load checkpoints (scope of this issue)
  3. The trainer library is responsible for being aware of when to call skip_first_batches

But regarding the failing test I mentioned above, I'm unsure whether it is related. The test that breaks when importing StatefulDataLoader is check_seedable_sampler, and what is very strange is that it fails without any code changes at all beyond importing the torchdata.stateful_dataloader package; the test itself is unchanged and uses a non-stateful PyTorch DataLoader. It's as if the import alone causes something to break.

byi8220 commented 3 months ago

Also, just to elaborate on the problem with StatefulDataLoader I'm running into, in case it's helpful info:

DataLoaderShard.__iter__() (https://github.com/huggingface/accelerate/blob/main/src/accelerate/data_loader.py#L445-L476) works by wrapping the underlying DataLoader.__iter__() and advancing it. The problem is that it eagerly fetches next_batch before yielding current_batch, which appears to be done to support self.end_of_dataloader.

Here's a crude mock of how I think this behavior works: https://pastebin.com/Sk1DfDYz

The problem here is that when the wrapper w1 is yielding 1, the inner itr has already yielded 2. In real code, w1 would be DataLoaderShard and itr would be the underlying DataLoader.
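In case the pastebin expires, here is a self-contained toy version (my own rewrite, not the exact mock) of that lookahead behavior:

```python
def lookahead_wrapper(inner):
    """Yields everything from `inner`, but always fetches one item ahead so it
    can tell which yield is the last one (mirroring end_of_dataloader)."""
    iterator = iter(inner)
    try:
        current = next(iterator)
    except StopIteration:
        return
    while True:
        try:
            nxt = next(iterator)  # the inner iterator is now one step ahead
        except StopIteration:
            # Nothing left to fetch: this is where end_of_dataloader would flip to True.
            yield current
            return
        yield current  # e.g. yields 1 while the inner iterator has already produced 2
        current = nxt


w1 = lookahead_wrapper([1, 2, 3])
print(list(w1))  # [1, 2, 3]
```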

One solution could be to implement DataLoaderShard.state_dict() to just keep the previous state_dict around, but this would introduce the overhead of calling StatefulDataLoader.state_dict() on every iteration, which might be expensive. Another option would be a refactor of the semantics of end_of_dataloader (which seems like a big refactor). Or maybe I'm missing a really obvious solution.
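To make the first option concrete, here is a rough sketch (a hypothetical helper, not proposed API) of keeping the previous state_dict around so the reported state matches the batch that was actually yielded rather than the prefetched one:

```python
from torchdata.stateful_dataloader import StatefulDataLoader


class StateTrackingWrapper:
    """Toy stand-in for DataLoaderShard: prefetches one batch ahead but reports
    the inner loader's state as it was before that prefetch happened."""

    def __init__(self, stateful_dataloader: StatefulDataLoader):
        self.dataloader = stateful_dataloader
        self._resume_state = None

    def __iter__(self):
        iterator = iter(self.dataloader)
        try:
            current_batch = next(iterator)
        except StopIteration:
            return
        # A snapshot taken right after a batch is fetched is the correct
        # resume point for that batch.
        state_for_current = self.dataloader.state_dict()
        while True:
            try:
                next_batch = next(iterator)  # inner loader is now one batch ahead
            except StopIteration:
                self._resume_state = state_for_current
                yield current_batch  # end_of_dataloader would be flagged here
                return
            self._resume_state = state_for_current
            # Per-iteration state_dict() call: the overhead mentioned above.
            state_for_current = self.dataloader.state_dict()
            yield current_batch
            current_batch = next_batch

    def state_dict(self):
        # Always corresponds to the last batch actually yielded, not the prefetched one.
        return self._resume_state
```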

byi8220 commented 3 months ago

Took a shot at getting StatefulDataLoader connected to this library in https://github.com/huggingface/accelerate/pull/2895. Seems like way more work than I would have imagined, and admittedly it's experimental and there may be issues.

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

byi8220 commented 2 months ago

Not sure if this should be closed, considering the PR is still open...