rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

PT DistributedDataParallel with mixed precision training #1473

Open albertz opened 10 months ago

albertz commented 10 months ago

I noticed that the DistributedDataParallel module has a mixed_precision option, which is meant for mixed precision training. We don't use it, even if the user specifies torch_amp to enable mixed precision. So I wonder: what happens if the user sets torch_distributed = {} (i.e. multi-GPU training via DistributedDataParallel) and also sets torch_amp = "bfloat16" (as an example)? Does this work correctly? Is it currently suboptimal? (I'm actually using this in some experiments, and memory consumption looks normal, just as in single-GPU training, but I did not check carefully.)
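For reference, the setup in question corresponds roughly to the RETURNN config fragment below. This is a minimal sketch: only `torch_distributed` and `torch_amp` come from the text above. `backend = "torch"` is my assumption for how the PyTorch backend is selected, and I assume the job is started through a multi-process launcher (e.g. `torchrun`), which is not shown.

```python
# Hypothetical minimal RETURNN config fragment (a RETURNN config is a Python file).
backend = "torch"  # assumption: selects the PyTorch backend

# Multi-GPU training via DistributedDataParallel, default options.
torch_distributed = {}

# Mixed precision training (AMP) in bfloat16.
torch_amp = "bfloat16"
```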

albertz commented 10 months ago

@kuacakuaca @Judyxujj any idea? Done that before?

albertz commented 10 months ago

From the "Working with Multiple GPUs" section of the PyTorch AMP documentation, it sounds like this should already be fine when we use DistributedDataParallel with one GPU per process.
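As I read that section, each process wraps its model in DistributedDataParallel and then runs the forward pass under `torch.autocast` as usual, with nothing DDP-specific required. Below is a minimal single-process sketch of that pattern. It uses the `gloo` backend on CPU with `world_size=1` only so that it runs without GPUs; real training would use `nccl`, one process per GPU, and `device_type="cuda"`. The toy model, data and the helper names `free_port` and `train_step_ddp_amp` are hypothetical placeholders, not RETURNN code.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def train_step_ddp_amp():
    # Single-process "distributed" setup so the sketch runs on CPU.
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{free_port()}", rank=0, world_size=1
    )
    try:
        torch.manual_seed(0)
        model = DDP(torch.nn.Linear(8, 4))  # wrap with DDP, then use autocast as usual
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        x, y = torch.randn(16, 8), torch.randn(16, 4)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            out = model(x)  # the linear layer runs in bfloat16
        loss = F.mse_loss(out.float(), y)  # loss computed in float32

        opt.zero_grad()
        loss.backward()  # DDP allreduces the gradients during backward
        opt.step()
        # Return activation dtype and gradient dtype for inspection.
        return out.dtype, model.module.weight.grad.dtype
    finally:
        dist.destroy_process_group()
```

Calling `train_step_ddp_amp()` performs one training step; the returned dtypes show that the activations are bfloat16 while the gradients DDP synchronized are float32.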

kuacakuaca commented 10 months ago

> @kuacakuaca @Judyxujj any idea? Done that before?

No. So you didn't observe a decrease in memory consumption?

albertz commented 10 months ago

> so you didn't observe a decrease in memory consumption?

Compared to what? Of course, enabling AMP (I usually use bfloat16 without a grad scaler) reduces GPU memory. But that is not what this issue is about; it is about distributed training. What I said is that going from single-GPU to multi-GPU training does not reduce memory, and there is no reason why it should. That was just an observation, which could be relevant, but it is not my question either.

My question is whether it actually works correctly. The observation hints that AMP is used in some way with distributed training as well (otherwise memory consumption would not match single-GPU training with AMP), but I'm still not sure that it is handled correctly with respect to the distributed part. With AMP, are the gradients then also bfloat16? Would distributed training then allreduce bfloat16 gradients, and thus also save communication bandwidth? Or does it maybe allreduce the wrong gradients, so that multi-GPU training is effectively not used correctly? That is my actual question.
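On whether the gradients are bfloat16: as far as I understand autocast, it casts the inputs of eligible ops to the lower precision, while the parameters themselves stay float32, so their `.grad` tensors (which are what DDP allreduces) are float32 too. If that holds here, plain DDP plus AMP would be correct but would not save communication bandwidth by itself. The small check below illustrates the dtypes on a toy model, without any distributed setup; the helper name `amp_dtypes` is hypothetical. Compressing the communication would need something extra, e.g. registering PyTorch's `bf16_compress_hook` via `DistributedDataParallel.register_comm_hook`; I have not checked how that would fit into RETURNN.

```python
import torch
import torch.nn.functional as F


def amp_dtypes() -> dict:
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 4)
    x, y = torch.randn(16, 8), torch.randn(16, 4)

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(x)  # computed in bfloat16 under autocast
    loss = F.mse_loss(out.float(), y)
    loss.backward()

    return {
        "activation": out.dtype,          # bfloat16
        "param": model.weight.dtype,      # parameters stay float32
        "grad": model.weight.grad.dtype,  # so gradients (what DDP reduces) are float32
    }


print(amp_dtypes())
```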

Judyxujj commented 10 months ago

@albertz I used mixed precision training together with torch distributed training in the fairseq framework to train the wav2vec2 model. With mixed precision, training was faster. But I haven't looked into the implementation details.