facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

[ShardedDDP] Handle transition to eval + parameter change #586

Closed: blefaudeux closed this issue 3 years ago

blefaudeux commented 3 years ago

šŸ› Bug

Reported by @SeanNaren: after a successful training run, the switch to eval() is not properly taken into account, and the grads are marked as "waiting to be reduced" when they should not be (we're in eval mode).

visionscaper commented 3 years ago

I'm trying to understand what effect, if any, this bug could have on the training loss when one switches to and from eval() between training iterations. Could you shed some light on this, @blefaudeux? I recently tried ShardedDDP and observed degraded training performance. I also switch to eval mode between iterations to calculate the validation set loss, hence the question.

blefaudeux commented 3 years ago

Oh, this materialized (as Sean saw it) as an assert firing (here), so it should not have gone unnoticed. If this assert did not fire, the only cost would have been extra network bandwidth (and thus speed) if you're calling .backward() somewhere, nothing more, so it should not have been anything serious unless I'm missing something.

Could you elaborate on the degraded training performance? You should get exactly the same results as DDP when the settings are the same; anything different is a bug, and I would like to know more.
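Since the expectation is parity with DDP under identical settings, one cheap way to localize a mismatch is to log the per-iteration loss of both runs (same seed, same data order) and find the first step where they diverge. A minimal sketch, assuming the losses of each run have already been collected into Python lists; `first_divergence` is a hypothetical helper, not part of fairscale, and the tolerance allows for small floating-point differences in reduction order:

```python
import math
from typing import Optional, Sequence


def first_divergence(
    reference: Sequence[float],
    candidate: Sequence[float],
    rel_tol: float = 1e-4,
    abs_tol: float = 1e-6,
) -> Optional[int]:
    """Return the index of the first step where the two loss curves disagree
    beyond the tolerance, or None if they match over their common length."""
    for step, (ref, cand) in enumerate(zip(reference, candidate)):
        if not math.isclose(ref, cand, rel_tol=rel_tol, abs_tol=abs_tol):
            return step
    return None
```

A divergence at step 0 would point at initialization or the first reduction, while a late divergence points at something that accumulates over training.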

blefaudeux commented 3 years ago

To elaborate, what can change is if you switch on ShardedDDP, see that you have some free memory, and bump up the batch size & LR compared to DDP. Depending on your training, you can definitely arrive at an LR/batch combination which degrades the accuracy, but that is not a bug; it has to do with the dynamics of training (and the optimizers having a hard job :)). AdaScale can help you with that.
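For reference, a common starting heuristic when raising the batch size is the linear scaling rule (a general rule of thumb, not something fairscale applies for you): multiply the LR by the same factor as the global batch size. Large factors are exactly where this can start to hurt accuracy, which is the problem adaptive schemes such as AdaScale target. `linearly_scaled_lr` is a hypothetical helper:

```python
def linearly_scaled_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling heuristic: grow the learning rate in proportion to the
    global batch size. A starting point only; large ratios often need warmup
    or an adaptive scheme."""
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch_size / base_batch_size
```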

If this is something else (same settings but worse accuracy than DDP), then something is definitely wrong; please help me root it out!

visionscaper commented 3 years ago

Hi @blefaudeux, thanks for responding so swiftly. No, this is not because I changed the batch size/LR settings; the degradation occurred under the same conditions. I need to find some time, but when I do, I will post some results and details here.

visionscaper commented 3 years ago

Hi @blefaudeux,

Please find my experiment results here:

| Model precision | ShardedDDP (No/FP32/FP16) | Train loss @ 200 iter | Val. loss @ 200 iter | Memory usage per GPU | Avg. batch training time |
|---|---|---|---|---|---|
| FP32 | No | 5.012 | 5.126 | ~9.5 GB | 1341 ms |
| FP32 | FP32 | 5.362 | 5.513 | ~8.8 GB | 987 ms |
| FP16 | No | 5.002 | 5.139 | ~7.6 GB | 1163 ms |
| FP16 | FP16 | 5.358 | 5.513 | ~7.1 GB | 1139 ms |
| FP16 | FP16, reduce_buffer_size=0 | 5.358 | 5.512 | ~7.0 GB | 1127 ms |
| FP16 | FP16, zero_grad() fix | 5.372 | 5.523 | ~7.1 GB | 1146 ms |

As you can see, when using ShardedDDP the loss after 200 batch iterations is significantly higher, both with and without mixed precision: about 0.35 higher training loss and 0.37 to 0.39 higher validation loss.
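The gaps can be recomputed directly from the table. The values below are copied from the four rows with default ShardedDDP settings; `RESULTS` and `gap` are illustrative names only:

```python
from typing import Dict

# (model precision, ShardedDDP setting) -> (train loss, val loss, avg batch ms),
# copied from the table above.
RESULTS = {
    ("FP32", "No"): (5.012, 5.126, 1341),
    ("FP32", "FP32"): (5.362, 5.513, 987),
    ("FP16", "No"): (5.002, 5.139, 1163),
    ("FP16", "FP16"): (5.358, 5.513, 1139),
}


def gap(precision: str, sharded: str) -> Dict[str, float]:
    """Compare a ShardedDDP run against the plain-DDP run of the same precision."""
    base_train, base_val, base_ms = RESULTS[(precision, "No")]
    train, val, ms = RESULTS[(precision, sharded)]
    return {
        "train_loss_gap": round(train - base_train, 3),
        "val_loss_gap": round(val - base_val, 3),
        "time_ratio": round(ms / base_ms, 3),  # < 1 means ShardedDDP is faster
    }


for precision in ("FP32", "FP16"):
    print(precision, gap(precision, precision))
```

The loss gap is nearly identical for both precisions, while the FP32 ShardedDDP run takes about 0.74x the batch time of its DDP baseline.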

Details about the model architecture that might be relevant:

Edit 1: I also set the random seed to 0, i.e. torch.manual_seed(seed).
Edit 2: Added results of the experiment with the reduce_buffer_size=0 flag for ShardedDDP.
Edit 3: Added results of the experiment with optimizer.zero_grad() changed to model.zero_grad().

blefaudeux commented 3 years ago

Thanks a lot @visionscaper, there's definitely something wrong: with the default flags you should get exactly the same results as DDP! Could you specify which flags were used for ShardedDDP and OSS, and which version of fairscale? In that config (single node) you can turn the buckets off for ShardedDDP (reduce_bucket_size = 0), but it should not impact the accuracy (you'll get a bit more speed and space).

blefaudeux commented 3 years ago

My guess is that something outside of the core model is not properly handled, maybe the embedding? Is it trainable?

visionscaper commented 3 years ago

OSS was instantiated as:

```python
from fairscale.optim.oss import OSS

base_optimizer_arguments = {'lr': lr}
...
OSS(params=params, optim=Optimizer, **base_optimizer_arguments)
```

Where Optimizer is Adam in this case

ShardedDDP was instantiated as:

```python
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

training_model = ShardedDDP(training_model, list(optimizers.values()))
```

Where optimizers is a dict with, in this case, the single Adam optimizer, wrapped by OSS.

So, no special flags applied.

visionscaper commented 3 years ago

Version of fairscale:

```
$ pip show fairscale
Name: fairscale
Version: 0.3.2
Summary: FairScale: A PyTorch library for large-scale and high-performance training.
Home-page: UNKNOWN
Author: Facebook AI Research
Author-email: todo@fb.com
License: UNKNOWN
Location: /home/freddy/.virtualenvs/nuhame/lib/python3.7/site-packages
Requires: torch
Required-by: 
```

visionscaper commented 3 years ago

I wouldn't know why the embedding layer would not be trainable when applying ShardedDDP but trainable otherwise.

There is one thing to note, though: when using conventional DDP, I train with the flag find_unused_parameters=False. If I remember correctly, I set this flag because otherwise I would get warnings about unused parameters. Could this have an effect when applying ShardedDDP?

visionscaper commented 3 years ago

@blefaudeux Did you mean reduce_buffer_size instead of reduce_bucket_size? I don't see a reduce_bucket_size flag for ShardedDDP.

blefaudeux commented 3 years ago

> I wouldn't know why the embedding layer would not be trainable in the case of applying ShardedDDP and trainable otherwise.
>
> There is one thing to note though, when using conventional DDP, I train with flag find_unused_parameters=False. If I remember correctly, I set this flag because otherwise I would get warnings concerning unused parameters. Could this have an effect when applying ShardedDDP?

It was an open question: are these embeddings trainable? ShardedDDP would not change that; I'm asking to get a better idea of the inputs. I'm just wondering whether something could be missed when ShardedDDP orchestrates moving the gradients around.
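To make "moving the gradients around" concrete: with OSS each rank owns the optimizer state for a subset of the parameters, and ShardedDDP reduces each gradient to the rank that owns it instead of all-reducing everything. The sketch below shows a simple size-balanced split, assuming parameters are assigned one at a time to the currently least-loaded rank; this is similar in spirit to how OSS partitions parameters but is not fairscale's code, and `partition_by_size` and the parameter names are hypothetical:

```python
from typing import Dict, List


def partition_by_size(param_sizes: Dict[str, int], world_size: int) -> List[List[str]]:
    """Greedily assign each parameter to the rank whose shard is currently
    smallest. The owning rank would hold that parameter's optimizer state and
    receive its reduced gradient."""
    shards: List[List[str]] = [[] for _ in range(world_size)]
    totals = [0] * world_size
    for name, numel in param_sizes.items():
        rank = totals.index(min(totals))  # least-loaded rank; ties go to the lowest index
        shards[rank].append(name)
        totals[rank] += numel
    return shards
```

Any parameter that the wrapper fails to register would simply never appear in this bookkeeping, which is the kind of omission being asked about.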

visionscaper commented 3 years ago

Ah, ok, so yes, this is a trainable embedding layer, basically it is torch.nn.Embedding.

visionscaper commented 3 years ago

@blefaudeux FYI. I added the results for the experiment with the reduce_buffer_size=0 flag for ShardedDDP to the table above.

blefaudeux commented 3 years ago

> @blefaudeux FYI. I added the results for the experiment with the reduce_buffer_size=0 flag for ShardedDDP to the table above.

Thanks, they make sense: same results, but a bit more memory and a bit faster without the reduce buffer. Edit: actually the FP32 result is a bit strange, it's faster than FP16? I'm guessing that the 1080 could be limited on that front.

blefaudeux commented 3 years ago

> Ah, ok, so yes, this is a trainable embedding layer, basically it is torch.nn.Embedding.

Thanks, I'll check that. My guess is that ShardedDDP is not picking that up, so this layer could be trained independently instead of seeing a bigger batch. I'll get back to you asap; trying to repro.

Edit: either this (which would be a bit strange, to be honest; if it's a normal parameter it should be picked up) or something specific to GRUs, which I've never tried.
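A toy model of that suspicion, with one scalar parameter (`owner_update` is a hypothetical helper, not fairscale code): under OSS the rank that owns a parameter performs its update and then broadcasts the result, so if the parameter's gradient is never reduced, the update reflects only the owner's local batch instead of the mean over all ranks.

```python
from typing import List


def owner_update(weight: float, grads_per_rank: List[float], lr: float,
                 owner: int, reduced: bool) -> float:
    """One SGD step for a scalar parameter owned by rank `owner`.

    If the gradient was reduced, the owner steps with the mean of all ranks'
    local-batch gradients; otherwise it only sees its own local gradient.
    The returned value is what the owner would broadcast to every rank."""
    if reduced:
        grad = sum(grads_per_rank) / len(grads_per_rank)
    else:
        grad = grads_per_rank[owner]
    return weight - lr * grad
```

With two ranks whose local gradients disagree in sign, the reduced and unreduced updates move the weight in opposite directions, so the effect on the loss curve can be much larger than "a smaller batch" suggests.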

visionscaper commented 3 years ago

@blefaudeux OK, if I can be of any help, just let me know. I think this is a great project; I especially love the "PyTorch-y" way you have modularised the functionality. IMO that's better than DeepSpeed, which breaks the PyTorch usage model with its configuration file.

visionscaper commented 3 years ago

> actually the fp32 result is a bit strange, it's faster than fp16 ? I'm guessing that 1080 could be limited on that front

Yes, I also noticed that; it could indeed be a constraint of the 1080 Ti. If you like, I can run an experiment on an AWS machine with Nvidia T4s.

visionscaper commented 3 years ago

@blefaudeux One last note: I was always under the impression that this could somehow be related to the decoder module being called many times in a loop, because of the autoregressive way the sequential model output is created. With "modern" transformer architectures you usually don't have to apply the model autoregressively during training.

Could this be something you have not tested? Or maybe the error wasn't noticeable because a layer, like an embedding layer, is usually called at most twice, not 40 or more times per model evaluation?

How do you actually deal with layers that are called more than once?

In any case, thanks for your time!

blefaudeux commented 3 years ago

> actually the fp32 result is a bit strange, it's faster than fp16 ? I'm guessing that 1080 could be limited on that front
>
> Yes, I also noticed that, it could indeed be a constraint on the 1080 Ti. If you like I can do an experiment on a AWS machine with Nvidia T4's.

As you prefer, but don't bother too much on that front; we have FP16 data from quite a few sources!

blefaudeux commented 3 years ago

> @blefaudeux One last note, I always was under the impression that this could somehow be related to the fact that the decoder module is called many times in a loop, because of the auto regressive nature of creating the sequential model output; when using "modern" transformer architectures you usually don't have to apply the model auto regressively during training.
>
> Could this be something you have not tested? Or maybe the error wasn't perceivable because a layer, like an embedding layer, is maybe called twice at most, and not 40 or more times to evaluate the model once?
>
> How do you actually deal with layers that are called more than once?
>
> In any case, thanks for your time!

Yeah, it's a good point, there could really be something in the architecture. ShardedDDP only deals with the backward pass, so calling the same blocks many times in a row should not change much. What could change, though, is how autograd produces the grads: if it calls the backward hook many times, the gradient will still only be reduced once, and maybe that is actually the problem. I'll try to repro along that axis.
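The failure mode described here (a hypothesis, not a confirmed bug) can be illustrated with a toy reducer holding scalar gradients. `ToyReducer` and `run` are hypothetical and do not model fairscale's hooks exactly: if a parameter's hook fires once per gradient contribution and the reduction is launched on the first call, the later contributions never reach the owning rank, whereas waiting for all expected contributions sends the full sum.

```python
from typing import List, Optional


class ToyReducer:
    """Hypothetical single-parameter reducer: one hook call per gradient
    contribution, and `reduced_value` is what would be sent to the owner."""

    def __init__(self, reduce_on_first_call: bool, expected_calls: int) -> None:
        self.reduce_on_first_call = reduce_on_first_call
        self.expected_calls = expected_calls
        self.accumulated = 0.0
        self.calls = 0
        self.reduced_value: Optional[float] = None

    def hook(self, grad: float) -> None:
        self.accumulated += grad
        self.calls += 1
        if self.reduced_value is not None:
            return  # already reduced: this contribution is silently dropped
        if self.reduce_on_first_call or self.calls == self.expected_calls:
            self.reduced_value = self.accumulated


def run(contributions: List[float], reduce_on_first_call: bool) -> Optional[float]:
    """Feed one contribution per use of the parameter (e.g. per decoder step)."""
    reducer = ToyReducer(reduce_on_first_call, expected_calls=len(contributions))
    for grad in contributions:
        reducer.hook(grad)
    return reducer.reduced_value
```

With three uses contributing 1.0, 2.0 and 3.0, the eager variant reduces 1.0 while the waiting variant reduces 6.0; the more often a layer is reused per step, the larger the share that would be lost.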

blefaudeux commented 3 years ago

I've been unable to repro with multiple forward passes through the same block and an embedding table (linked branch); a unit test shows ShardedDDP at parity with DDP, so a priori it's something else. Looking into other options.

visionscaper commented 3 years ago

@blefaudeux I've been looking at your test and comparing it to what I do in my experiments. One thing that stands out is that I call optimizer.zero_grad() and you call model.zero_grad(). This might be the issue.

visionscaper commented 3 years ago

I added the results of an experiment with optimizer.zero_grad() changed to model.zero_grad(). It doesn't seem to matter.

visionscaper commented 3 years ago

I think I will need to write a minimal test to pin down the issue; right now there is so much complexity in my code that it could obscure the root cause. Maybe there is a simple explanation. I'm not sure when I'll have the time to write such a test, though. @blefaudeux, if you have any ideas in the meantime, please let me know.

blefaudeux commented 3 years ago

> @blefaudeux I've been looking at your test and compared it to what I do in my experiments. One thing that stands out is that I do optimizer.zero_grad() and you do model.zero_grad(). This might be the issue.

Ah yes, this can be a problem with the ShardedDDP buckets: if you zero the grads through the optimizer, it cannot reset the bucket gradients properly (this could be changed to happen by default upon reduction, which would make sense but is a bit heavy-handed). I think it's mentioned in the docs, but it's easy to miss. Without buckets it does not change anything, though, so it looks like it's not your issue. A repro would be ideal if it's not too much work; I could add it to the test suite. We already have GPT-2 there, for instance, because of specific issues it exposed.
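A stdlib toy of why the way gradients are zeroed can matter once buckets are involved. It assumes (this is not taken from fairscale's code) that per-parameter gradients are views into one flat reduce buffer: zeroing each view in place also clears the shared buffer, while merely dropping the per-parameter references, as setting `.grad` to None does, leaves stale values in it. All function names are hypothetical.

```python
from array import array
from typing import List, Optional, Tuple


def make_bucket(sizes: List[int]) -> Tuple[array, List[memoryview]]:
    """Allocate one flat buffer and return a per-parameter view into it."""
    buffer = array("d", [0.0] * sum(sizes))
    flat = memoryview(buffer)
    views, offset = [], 0
    for size in sizes:
        views.append(flat[offset:offset + size])  # a view: writes land in `buffer`
        offset += size
    return buffer, views


def accumulate(views: List[memoryview], value: float) -> None:
    """Stand-in for backward(): add `value` to every gradient element."""
    for view in views:
        for i in range(len(view)):
            view[i] += value


def zero_in_place(views: List[memoryview]) -> None:
    """Like zeroing each gradient in place: the shared buffer is cleared too."""
    for view in views:
        for i in range(len(view)):
            view[i] = 0.0


def drop_grads(views: List[memoryview]) -> List[Optional[memoryview]]:
    """Like setting each gradient to None: the references are gone, but the
    values already sitting in the shared buffer are left untouched."""
    return [None for _ in views]
```

Whether a particular `zero_grad()` call zeroes in place or sets gradients to None depends on the PyTorch version and its `set_to_none` argument, so the toy only shows why the two can behave differently with a shared buffer.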