learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

In distributed_maml.py, when are the gradients communicated across GPUs? #391

Closed · jkang1640 closed this issue 1 year ago

jkang1640 commented 1 year ago

Hello,

Thanks to the example code, I was able to implement MAML with DDP for a seq2seq model.

While implementing it, a question came up about when the gradients are reduced. If I understand correctly, when a model is wrapped in DDP, every backward() step implicitly reduces (averages) the gradients across GPUs.

In the example code, I guess that cherry's `opt.step()  # averages gradients across all workers` does the job, as the comment says. Does that mean the gradients are not reduced at any point before then, not even during backward()?
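
For context, here is a minimal sketch of the structure I have in mind (`MyModel`, `sample_tasks`, and `compute_loss` are placeholders, and the cherry calls follow the example as I understand it, so please correct me if I misread it):

```python
import torch
import cherry
import learn2learn as l2l

def worker(rank, world_size, meta_lr=0.003):
    # torch.distributed is assumed to be initialized already (one process per GPU).
    model = MyModel()                                      # placeholder seq2seq model
    maml = l2l.algorithms.MAML(model, lr=0.5, first_order=False)

    opt = torch.optim.Adam(maml.parameters(), lr=meta_lr)
    opt = cherry.optim.Distributed(maml.parameters(), opt, sync=1)
    opt.sync_parameters()                                  # start all workers from identical weights

    for iteration in range(1000):                          # placeholder iteration count
        opt.zero_grad()
        for task in sample_tasks():                        # placeholder task sampler
            learner = maml.clone()
            learner.adapt(compute_loss(learner, task))     # inner-loop adaptation
            compute_loss(learner, task).backward()         # local .grad accumulation only?
        opt.step()                                         # the only point where grads are averaged?
```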

Thank you very much for your help!

seba-1511 commented 1 year ago

Hi @jkang1640,

The distributed optimizer in cherry is an alternative to DDP, so you don't need to wrap your model to distribute it. If you did, I believe it would simply average the gradients twice, which adds unnecessary communication overhead.

The example you linked uses cherry and not DDP because I'm not sure how DDP's backward reduction would work with second-order gradients -- I just haven't tested it thoroughly.
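
Roughly, the distinction looks like this (a sketch, assuming `model`, `maml`, and `meta_lr` come from the surrounding training script; treat it as an illustration rather than tested code):

```python
import torch
import cherry

# What the example does: keep the plain module and distribute the *optimizer*.
# backward() only fills the local .grad buffers; the all-reduce that averages
# gradients across workers happens inside opt.step(), which then applies the update.
opt = torch.optim.Adam(maml.parameters(), lr=meta_lr)
opt = cherry.optim.Distributed(maml.parameters(), opt, sync=1)
opt.sync_parameters()   # broadcast initial weights so every replica matches

# Redundant option: additionally wrapping the model in DDP would make backward()
# all-reduce the gradients as well, so they would effectively be averaged twice
# (once by DDP's hooks, once by opt.step()).
# ddp_model = torch.nn.parallel.DistributedDataParallel(model)   # not needed here
```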