It looks like the issue is that the seed set via torch.manual_seed is used by both nn.Dropout and by the data loader.
It seems that the data_seed argument is not used, but it should be possible to use it to seed the random sampler here.
Since we can only influence dropout via torch.manual_seed, can I implement a change so that data_seed is used to seed the RandomSampler? Or would that break backwards compatibility, in which case I could add a new argument for this and deprecate the current seeding behaviour?
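For illustration, this is roughly the kind of change I have in mind (a minimal sketch, not the actual Trainer internals; the helper name build_sampler is made up):

import torch
from torch.utils.data import RandomSampler

def build_sampler(train_dataset, data_seed):
    # Seed a dedicated generator from data_seed so the sampler no longer
    # depends on the global torch.manual_seed state used by dropout.
    generator = torch.Generator()
    generator.manual_seed(data_seed)
    return RandomSampler(train_dataset, generator=generator)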
Hi @ri938
thanks for this interesting issue. I am not really familiar with the way accelerate sets the seed for the data sampler. I am also not sure how you set both the seed for dropout and the seed for the sampler in your code; could you share more details about that?
So I set the seed on startup to the same value, 100, on each device:
def set_training_seed(seed):
    from transformers import set_seed
    set_seed(seed)
This ensures that each device has the same weight initialization before training starts.
Then I set the seed in the TrainingArguments that get passed to the Trainer to the same constant value, 100:
trainer_args = TrainingArguments(seed=100, **kwargs)
The seed set by torch.manual_seed(x) is what affects dropout. It also affects the RandomSampler.
Therefore there is no way to make the dropout masks differ across devices without also breaking the data ordering, because both depend on the same seed, which has to be identical on every device.
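A self-contained sketch (not the Trainer code path, just the underlying PyTorch behaviour) showing the coupling:

import torch
import torch.nn as nn
from torch.utils.data import RandomSampler

torch.manual_seed(100)
mask_a = nn.Dropout(p=0.5)(torch.ones(8))   # dropout mask drawn from the global RNG
order_a = list(RandomSampler(range(8)))     # sampler order also derived from the global RNG state

torch.manual_seed(100)
mask_b = nn.Dropout(p=0.5)(torch.ones(8))
order_b = list(RandomSampler(range(8)))

# Same seed -> identical dropout masks and identical data order,
# which is exactly what happens across devices in DDP.
assert torch.equal(mask_a, mask_b) and order_a == order_b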
I would argue this is potentially an issue affecting many training runs for many users. There should therefore be both a way to avoid it and a warning or error so that people do not train unaware of it.
It seems that the data_seed argument is not used, but it should be possible to use it to seed the random sampler here.
Hi @ri938, you are right: the class variable data_seed is not used, and set_seed is used for both data sampling and training. Please refer to the discussion in issue #31255.
Yes, I was suggesting that using data_seed for the data sampling could fix this issue, but that would break backwards compatibility.
Here is another image to illustrate the problem. When training GPT-2, the gradient norms are huge when you use the same seed on every device, but when you vary the seed per device they are much more sensible.
This is the workaround I am using to fix the issue: I add the following callback.
import os
import torch
from transformers import TrainerCallback

class SeedDeviceRandomlyCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        # Offset the base seed by the process rank so each device draws
        # different dropout masks once training starts.
        global_rank = int(os.environ['RANK'])
        new_seed = args.seed + global_rank
        print('Setting torch seed to {} on device {}'.format(new_seed, global_rank))
        torch.manual_seed(new_seed)
A callback is needed because the seed has to be made different per device only after get_train_dataloader has been called; otherwise the data ordering would break as well.
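For completeness, this is how I attach it; model and train_dataset are placeholders for whatever you are training:

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,                      # placeholder model
    args=TrainingArguments(output_dir="out", seed=100),
    train_dataset=train_dataset,      # placeholder dataset
    callbacks=[SeedDeviceRandomlyCallback()],
)
trainer.train()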
After applying just this one callback, here is a demonstration of how much it improved performance:
Would be nice to have this merged then!
@ri938 do you want to open a PR with your proposed changes?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
GPT2
torch==2.3.1
DDP
transformers==4.41.2 (using the Trainer)
Who can help?
@muellerzr @SunMarc (Trainer code)
@ArthurZucker and @younesbelkada (text models)
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
On 1 GPU I get:
On 4 GPUs I get the same dropout mask being applied:
With 1 GPU the data is unique.
With 4 GPUs you get duplicate data across devices.
Expected behavior
There should be a way to set the random seed that controls dropout without destroying the data ordering when doing DDP.
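For example, something like the following is the usage I would expect once data_seed is actually wired through to the sampler (hypothetical behaviour, since today data_seed is ignored):

import os
from transformers import TrainingArguments

rank = int(os.environ.get("RANK", "0"))
args = TrainingArguments(
    output_dir="out",
    seed=100 + rank,   # per-device seed -> different dropout masks on each rank
    data_seed=100,     # shared seed -> identical shuffling/sharding on every rank
)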
I am happy to submit an MR to fix this issue if given some pointers on how to implement it.