byi8220 closed this 4 months ago.
Alright, so I'm pretty convinced that these lines are the culprit:
Repeated attempts at looking into this show an inconsistency in which sampler implementation gets used, and simply importing stateful_dataloader before anything else makes this test pass.
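Why import order can matter here: a from-import binds a name to whatever object the source module holds at that moment, so code imported before a monkey-patch keeps the original class while code imported after it sees the replacement. The stdlib-only sketch below shows that mechanism. The module toy_data and the helpers load_consumer and apply_patch are hypothetical stand-ins (for torch.utils.data and for a library's import-time patch), not torchdata's actual code.

```python
import sys
import types

# Hypothetical stand-in for a library module that defines a sampler class
# (plays the role of torch.utils.data in this sketch).
toy_data = types.ModuleType("toy_data")


class BatchSampler:
    name = "original"


toy_data.BatchSampler = BatchSampler
sys.modules["toy_data"] = toy_data


def load_consumer():
    """Simulate a consumer module doing `from toy_data import BatchSampler`:
    the name is bound to whatever object toy_data.BatchSampler is right now."""
    from toy_data import BatchSampler as bound
    return bound


def apply_patch():
    """Simulate a patching library's import-time side effect: it replaces
    the attribute on the shared module object."""
    class PatchedBatchSampler(BatchSampler):
        name = "patched"

    sys.modules["toy_data"].BatchSampler = PatchedBatchSampler


# A consumer imported before the patch keeps the original class...
early = load_consumer()
apply_patch()
# ...while a consumer imported after the patch sees the replacement.
late = load_consumer()

print(early.name, late.name)  # original patched
```

The two consumers disagree about which class BatchSampler is, which is the kind of inconsistency that import order alone can create.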
If this is the case, I can think of 3 possible solutions.
torchdata maintainers. I'm not sure if this is their problem to fix, but the fact that their imports redefine fundamental PyTorch types such as RandomSampler and BatchSampler is dubious to me.
Ah the joys of monkey-patching 😓
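One way to confirm which fundamental types an import redefines is to snapshot a module's attributes beforehand and compare object identity afterwards. The sketch below assumes a hypothetical stand-in module named data in place of torch.utils.data; snapshot and replaced_names are illustrative helpers, not part of any library.

```python
import types

# Hypothetical stand-in for a module such as torch.utils.data.
data = types.ModuleType("data")


class RandomSampler:
    pass


class BatchSampler:
    pass


data.RandomSampler = RandomSampler
data.BatchSampler = BatchSampler


def snapshot(module, names):
    """Record the object currently bound to each name on the module."""
    return {name: getattr(module, name) for name in names}


def replaced_names(module, before):
    """Return the names whose binding is no longer the snapshotted object."""
    return sorted(name for name, obj in before.items() if getattr(module, name) is not obj)


before = snapshot(data, ["RandomSampler", "BatchSampler"])


# Simulate an import-time side effect that swaps in a subclass.
class PatchedBatchSampler(BatchSampler):
    pass


data.BatchSampler = PatchedBatchSampler

print(replaced_names(data, before))  # ['BatchSampler']
```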
I've pinged the torchdata team internally, we'll come to some solution :)
Thanks!
Well, given this situation, I can't really recommend merging https://github.com/huggingface/accelerate/pull/2895 even if it looks good.
No worries, thank you for working on the initial support! Depending on what happens next, we can move the discussion of how to go forward to that PR. (And thank you so much for doing that!)
I'll give that PR a review in the AM when I can look thoroughly at what you've done (great work BTW)
SG, thanks. It passes the test cases that I've written, but admittedly the implementation is rather hacky, and I can only test on my local machine (a single RTX 3060 Ti).
I was mostly trying to write the bare minimum code needed to get it working, and even that turned out to be incredibly invasive.
A good solution has been found 🤗
This should be good to close, since tests now pass after upgrading torchdata to their nightly 0.7.1.dev20240703+cpu.
Merely importing StatefulDataLoader from the nightly torchdata package (i.e. putting the line from torchdata.stateful_dataloader import StatefulDataLoader anywhere in the code) causes one of the unit tests, check_seedable_sampler, to fail.
Stack trace obtained by running tests with the import:
I suspect this has something to do with torchdata overriding torch's BatchSampler in this code. This is supported by the fact that, if I add the import along with some logging, SeedableRandomSampler.__iter__() appears to be called one fewer time than expected:
How to reproduce:
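The kind of logging described above can be sketched as a class decorator that wraps __iter__ with a call counter, so a missing pass shows up as a count below the number of epochs. ToySampler is a hypothetical stand-in for SeedableRandomSampler, and count_iter_calls is an illustrative helper, not an accelerate or torchdata API.

```python
import functools
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sampler-trace")


def count_iter_calls(cls):
    """Wrap cls.__iter__ so every call increments cls.iter_calls and is logged."""
    original_iter = cls.__iter__

    @functools.wraps(original_iter)
    def traced_iter(self):
        cls.iter_calls += 1
        log.info("%s.__iter__ call #%d", cls.__name__, cls.iter_calls)
        return original_iter(self)

    cls.iter_calls = 0
    cls.__iter__ = traced_iter
    return cls


@count_iter_calls
class ToySampler:
    """Hypothetical stand-in for SeedableRandomSampler: yields indices 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


sampler = ToySampler(4)
epochs = 3
for _ in range(epochs):
    list(sampler)  # one full pass per epoch should call __iter__ exactly once

assert ToySampler.iter_calls == epochs, "sampler was iterated fewer times than expected"
```

If something upstream (such as a replaced BatchSampler) skipped a pass over the sampler, the final count would come out one short, matching the symptom described above.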
Run make test and expect 1 failing test.