huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Importing `torchdata.stateful_dataloader` causes the test `check_seedable_sampler` to fail #2894

Closed: byi8220 closed this issue 4 months ago

byi8220 commented 4 months ago

Merely importing StatefulDataLoader from the nightly torchdata package (i.e. putting the line from torchdata.stateful_dataloader import StatefulDataLoader anywhere in the code) causes one of the unit tests, check_seedable_sampler, to fail.

Stack trace obtained by running the tests with the import in place:

stderr: Traceback (most recent call last):
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 827, in <module>
stderr:     main()
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 802, in main
stderr:     check_seedable_sampler()
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 381, in check_seedable_sampler
stderr:     assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
stderr:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: AssertionError: Did not obtain the same items with the same seed and epoch.

I suspect it has something to do with torchdata overriding torch's BatchSampler in this code. This is supported by the fact that, when I add the import along with some logging, SeedableRandomSampler.__iter__() appears to be called one fewer time than expected:

# We should see the epoch and seed sequence [(0, 42), (1, 43), (2, 44)] twice, but the first call with seed 42 is missing
# It looks like the first sample is being drawn without setting a seed

stdout: stdout: Shuffled central dataloader passing.
stdout: stdout: {'x': tensor([-1.3022,  0.1278], device='cuda:0'), 'y': tensor([0.3097, 3.2926], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.0400,  0.7505], device='cuda:0'), 'y': tensor([0.9978, 4.5075], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.3047, -0.0168], device='cuda:0'), 'y': tensor([3.6974, 3.0542], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.8530,  0.9406], device='cuda:0'), 'y': tensor([1.2889, 4.9939], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.3162, -1.9510], device='cuda:0'), 'y': tensor([ 2.2716, -0.8553], device='cuda:0')}
stdout: stdout: --
stdout: stdout: Setting seed at epoch 1 43
stdout: stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: stdout: --
stdout: stdout: Setting seed at epoch 2 44
stdout: stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: stdout: --
stdout: stdout: Resetting epoch and seed
stdout: stdout: Setting seed at epoch 0 42
stdout: stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.0400, -0.0168], device='cuda:0'), 'y': tensor([0.9978, 3.0542], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.9510, -1.3022], device='cuda:0'), 'y': tensor([-0.8553,  0.3097], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.3047, -0.8530], device='cuda:0'), 'y': tensor([3.6974, 1.2889], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.9406, -0.3162], device='cuda:0'), 'y': tensor([4.9939, 2.2716], device='cuda:0')}
stdout: stdout: --
stdout: stdout: Setting seed at epoch 1 43
stdout: stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: stdout: --
stdout: stdout: Setting seed at epoch 2 44
stdout: stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: stdout: --
stdout: stdout: original_items:
stdout: stdout:  tensor([-1.3022,  0.1278, -1.0400,  0.7505,  0.3047, -0.0168, -0.8530,  0.9406,
stdout: stdout:         -0.3162, -1.9510, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout: stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout: stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
stdout: stdout: new_items:
stdout: stdout:  tensor([ 0.7505,  0.1278, -1.0400, -0.0168, -1.9510, -1.3022,  0.3047, -0.8530,
stdout: stdout:          0.9406, -0.3162, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout: stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout: stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
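The pattern in this log (epochs 1 and 2 reproduce exactly, epoch 0 does not) is what you would expect if the very first pass draws its order without going through the sampler's per-epoch seeding. The pure-Python toy model below illustrates that reading. `ToySeedableSampler` and `run_epochs` are hypothetical names; the code does not use torch or Accelerate's real `SeedableRandomSampler`, and the "skipped seed" is simulated rather than reproduced.

```python
import random


class ToySeedableSampler:
    """Hypothetical stand-in for a sampler that reseeds at every epoch."""

    def __init__(self, n, initial_seed=42):
        self.n = n
        self.initial_seed = initial_seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Reseed from initial_seed + epoch, so each epoch's order is reproducible
        rng = random.Random(self.initial_seed + self.epoch)
        order = list(range(self.n))
        rng.shuffle(order)
        return iter(order)


def run_epochs(sampler, epochs, skip_first_seed=False):
    """Collect the index order of each epoch.

    With skip_first_seed=True, epoch 0 is shuffled by an unseeded RNG instead
    of going through the sampler's __iter__, mimicking the missing seed call.
    """
    orders = []
    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        if skip_first_seed and epoch == 0:
            order = list(range(sampler.n))
            random.Random().shuffle(order)  # seeded from OS entropy, not 42
        else:
            order = list(sampler)
        orders.append(order)
    return orders


# First run: the seeded __iter__ is bypassed on epoch 0
original = run_epochs(ToySeedableSampler(10), epochs=3, skip_first_seed=True)
# Second run after "resetting epoch and seed": every epoch is seeded
replay = run_epochs(ToySeedableSampler(10), epochs=3)

print(original[1:] == replay[1:])  # True: the seeded epochs match
print(original[0] == replay[0])    # almost always False: epoch 0 was never seeded
```

As in the log, only the first epoch of the first run diverges, which is why comparing the concatenated items fails even though later epochs line up.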

How to reproduce:

  1. Install the torchdata nightly:
    pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
  2. Import stateful_dataloader in the test check_seedable_sampler (`+` marks the added line):
         def check_seedable_sampler():
    +        import torchdata.stateful_dataloader
             # Set seed
             set_seed(42)
             train_set = RegressionDataset(length=10, seed=42)
             train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
             ...
  3. Run `make test`; one test fails.
byi8220 commented 4 months ago

Alright, so I'm pretty convinced that these lines are the culprit:

  1. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L61-L62
  2. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L134-L135

After several attempts at looking into this, I see an inconsistency in which sampler implementation gets used. By simply importing stateful_dataloader before anything else, I managed to get this test to pass.
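Why import order can matter here: in Python, `from module import Name` copies a reference at import time, while code that reads `module.Name` at call time sees whatever the attribute is bound to at that moment. If another package rebinds the attribute in between, different code paths can end up with different classes. The sketch below shows this with a fake module built via `types.ModuleType`; `fake_lib`, `OriginalSampler`, `PatchedSampler`, and `build_sampler` are hypothetical and do not reproduce torchdata's actual patch.

```python
import types

# A fake library module standing in for something like torch.utils.data.sampler
fake_lib = types.ModuleType("fake_lib")


class OriginalSampler:
    name = "original"


fake_lib.Sampler = OriginalSampler

# Early binding: equivalent to "from fake_lib import Sampler" run before any patch
EarlySampler = fake_lib.Sampler


def build_sampler():
    # Late binding: looks up fake_lib.Sampler every time it is called
    return fake_lib.Sampler()


# Side effect of a later third-party import: the module attribute is rebound
class PatchedSampler(OriginalSampler):
    name = "patched"


fake_lib.Sampler = PatchedSampler

print(EarlySampler().name)   # original
print(build_sampler().name)  # patched
print(isinstance(build_sampler(), EarlySampler))  # True in this toy, since PatchedSampler subclasses it
```

Which class a given code path receives depends on both the lookup style and whether the patching import has already run, which is consistent with importing `stateful_dataloader` first changing the test outcome.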

If this is the case, I can think of 3 possible solutions.

  1. Make sure that we always import this first. This seems very fragile.
  2. Somehow resolve this issue in code. Unless someone better at Python than me has seen this problem before, this seems like an absolute nightmare of a code change.
  3. File an issue with the torchdata maintainers. I'm not sure if this is theirs to fix, but the fact that importing their module redefines fundamental PyTorch types such as RandomSampler and BatchSampler seems dubious to me.
muellerzr commented 4 months ago

Ah the joys of monkey-patching 😓

Re: option 3, this is a them issue IMO. At most we can guard imports once we officially support their dataloaders, but Accelerate is designed to work with native PyTorch dataloaders. A monkey-patching approach like this needs to be carefully guarded when it touches such a core feature.
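One way to guard against this kind of monkey-patching is to detect which attributes of a module an import rebinds. The helper below is a hypothetical sketch, not an Accelerate or torchdata API: it takes a shallow copy of the module namespace, runs an import callback, and reports names that now point to a different object. The self-check uses a fake module; against the real libraries you might pass `torch.utils.data.sampler` and a callback that imports `torchdata.stateful_dataloader`, which assumes that is where the patch lands.

```python
import types


def rebound_attributes(module, run_import):
    """Return sorted names of module attributes that run_import() rebinds or adds."""
    # Shallow copy keeps references to the original objects for an identity check
    before = dict(vars(module))
    run_import()
    return sorted(
        name
        for name, value in vars(module).items()
        if name not in before or before[name] is not value
    )


# Fake module with two sampler-like classes (hypothetical stand-ins)
fake_sampler_module = types.ModuleType("fake_sampler_module")


class ToyRandomSampler:
    pass


class ToyBatchSampler:
    pass


fake_sampler_module.RandomSampler = ToyRandomSampler
fake_sampler_module.BatchSampler = ToyBatchSampler


def patching_import():
    # Stands in for an import whose side effect replaces one class
    class PatchedBatchSampler(ToyBatchSampler):
        pass

    fake_sampler_module.BatchSampler = PatchedBatchSampler


changed = rebound_attributes(fake_sampler_module, patching_import)
if changed:
    print(f"warning: import rebound {changed}")  # warning: import rebound ['BatchSampler']
```

Running a check like this in a test turns a silent change of sampler class into an explicit message, instead of relying on import order.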
muellerzr commented 4 months ago

I've pinged the torchdata team internally, we'll come to some solution :)

byi8220 commented 4 months ago

Thanks!

Well, given this situation, I can't really recommend merging https://github.com/huggingface/accelerate/pull/2895, even if it looks good.

muellerzr commented 4 months ago

No worries, thank you for working on the initial support! Based on what happens next, we can move the discussion about how to go forward to that PR. (And thank you so much for doing that!)

muellerzr commented 4 months ago

I'll give that PR a review in the AM when I can look thoroughly at what you've done (great work BTW)

byi8220 commented 4 months ago

SG, thanks. It passes the test cases I've written, but admittedly the implementation is rather hacky, and I can only test on my local machine (a single RTX 3060 Ti).

I was kinda just trying to write the bare minimum of code needed to get it working, and even that turned out to be incredibly invasive.

muellerzr commented 4 months ago

A good solution has been found 🤗

https://github.com/pytorch/data/pull/1281

byi8220 commented 4 months ago

This should be good to close, since the tests now pass after upgrading torchdata to their nightly 0.7.1.dev20240703+cpu.