Open rwightman opened 2 years ago
FYI, item 2 essentially means that all training ends up as `ResampledShards()`, since the distributed workers all get seeded differently (I confirmed this with a test).
For (1), yes, to make exact shuffling across nodes work, you either need to be very careful in how you set up your epochs or you need to use some external synchronization. I've used Redis in the past for that purpose. We haven't invested much time in it since torchdata will hopefully provide new mechanisms to deal with this case.
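A minimal sketch of that kind of external synchronization, assuming a reachable Redis instance; the host name and the `wds:epoch` key are made up for illustration, and this is not webdataset API:

```python
import redis

# Hypothetical shared epoch counter kept in Redis so every node derives the same
# shuffle seed. Host/port and the "wds:epoch" key name are assumptions.
r = redis.Redis(host="head-node", port=6379)

def publish_epoch(epoch: int) -> None:
    # Rank 0 calls this at the start of each epoch.
    r.set("wds:epoch", epoch)

def read_epoch() -> int:
    # Every rank/worker reads this before (re)seeding its shuffle, so all of them
    # agree on the epoch even if their local iteration counts have drifted.
    value = r.get("wds:epoch")
    return int(value) if value is not None else 0
```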
For (2), you can use the PYTHONHASHSEED environment variable. But I'll change the code so that won't be necessary anymore.
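To make the hash issue concrete, here is a small self-contained check (nothing webdataset-specific, just standard library) showing that the same tuple hashes differently in fresh interpreters unless PYTHONHASHSEED is pinned:

```python
import os
import subprocess
import sys

# The kind of tuple a detshuffle-style stage might pass to random.Random();
# hash() of it is what actually ends up being used as the seed.
CODE = 'print(hash(("shuffle", 0, 7)))'

def hash_in_fresh_interpreter(hashseed: str) -> str:
    env = dict(os.environ, PYTHONHASHSEED=hashseed)
    result = subprocess.run([sys.executable, "-c", CODE],
                            capture_output=True, text=True, env=env)
    return result.stdout.strip()

# Randomized hashing: two interpreter invocations (think: two nodes) disagree.
print(hash_in_fresh_interpreter("random"), hash_in_fresh_interpreter("random"))

# Pinned PYTHONHASHSEED: the hash, and therefore the derived seed, is reproducible.
print(hash_in_fresh_interpreter("0"), hash_in_fresh_interpreter("0"))
```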
I've been trying to debug and resolve a number of distributed training shuffle problems recently, and I've found some alarming issues...
1. There is no way to have a reliable epoch-count-based deterministic seed (i.e. `detshuffle`) without sharing an epoch counter across the train loop and the pipeline in a process-safe fashion (i.e. `mp.Value`), and that is across all combinations of DataLoader `persistent_workers=True/False`, etc. It's especially a problem when one uses `with_epoch` and the wrap may happen at different times (the wrap increments the epoch) if sample counts aren't perfectly balanced across shards and shards aren't perfectly distributed across processes. Some workers might roll over and increment their epoch while others stay on the old value, and by the next epoch they are all out of sync. (See the `mp.Value` sketch after this list.)
2. `random.Random()` seeding is being done with a tuple; this uses `hash()` of the tuple as the seed, and that hash can differ across Python runtimes (interpreter invocations, and thus machines in a distributed run). This is a pretty big one. (See the second sketch below.)
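A minimal sketch of the process-safe epoch sharing described in (1), assuming fork-started DataLoader workers so the `mp.Value` is actually shared; `SharedEpochShuffle` is a made-up stand-in for illustration, not webdataset's `detshuffle`:

```python
import random
from multiprocessing import Value

# Shared, process-safe epoch counter: the train loop writes it, every DataLoader
# worker reads it, so no worker's locally-tracked epoch can drift out of sync.
shared_epoch = Value("i", 0)

class SharedEpochShuffle:
    """Stand-in for a detshuffle-like stage that seeds from the shared epoch."""

    def __init__(self, epoch, base_seed=0):
        self.epoch = epoch
        self.base_seed = base_seed

    def __call__(self, samples):
        # Every worker derives the same RNG for a given epoch, regardless of where
        # its with_epoch() wrap happened or whether persistent_workers is set.
        rng = random.Random(self.base_seed + self.epoch.value)
        buf = list(samples)  # simplified: full-buffer shuffle, just for illustration
        rng.shuffle(buf)
        yield from buf

# The train loop owns the increment; workers only ever read the value.
# for epoch in range(num_epochs):
#     with shared_epoch.get_lock():
#         shared_epoch.value = epoch
#     train_one_epoch(loader)
```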
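And to illustrate (2): a seed can be derived from the tuple contents with a cryptographic hash instead of the built-in `hash()`, which makes it identical across interpreters and machines regardless of PYTHONHASHSEED (a sketch of one possible workaround, not necessarily how the library should fix it):

```python
import hashlib
import random

def stable_seed(*parts) -> int:
    # Hash the repr of the parts with sha256 (independent of PYTHONHASHSEED),
    # then take 64 bits of the digest as an integer seed.
    digest = hashlib.sha256(repr(parts).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little")

# Identical on every node and every interpreter invocation, unlike
# random.Random(("shuffle", rank, epoch)), which silently seeds from hash().
rank, epoch = 3, 7
rng = random.Random(stable_seed("shuffle", rank, epoch))
```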