leopard-ai / betty

Betty: an automatic differentiation library for generalized meta-learning and multilevel optimization
https://leopard-ai.github.io/betty/
Apache License 2.0

[Request] Improve distributed training performance #8

Closed sangkeun00 closed 1 year ago

sangkeun00 commented 1 year ago

PyTorch minimizes throughput degradation in distributed training by overlapping communication with computation. However, Betty currently performs all computation first and then manually performs gradient synchronization, so it does not benefit from this overlap. This is mainly because hypergradient calculation often requires second-order gradient computation as well as multiple forward-backward passes. To improve distributed training performance, we can: 1) make use of PyTorch's native communication-computation overlap by replacing torch.autograd.grad with torch.autograd.backward, and 2) keep most of the hypergradient computation local, performing gradient synchronization only once at the end. A rough sketch of both ideas is given below.
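As a minimal sketch (not Betty's actual implementation), the snippet below illustrates how the two ideas could look with a standard `DistributedDataParallel` module: intermediate backward passes run inside `no_sync()` so they stay local, and the final `loss.backward()` (i.e. `torch.autograd.backward`, which triggers DDP's reducer hooks, unlike `torch.autograd.grad`) performs a single all-reduce that overlaps with the backward computation. The names `local_hypergradient_backward`, `loss_fn`, and `batches` are hypothetical placeholders for illustration.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def local_hypergradient_backward(ddp_module: DDP, loss_fn, batches):
    """Sketch of idea (2): keep intermediate backward passes local and
    synchronize gradients across workers only on the last backward."""
    # DDP's no_sync() context suppresses the per-backward all-reduce, so the
    # repeated backward passes needed for hypergradients stay on each worker.
    with ddp_module.no_sync():
        for batch in batches[:-1]:
            loss = loss_fn(ddp_module(batch))
            # Idea (1): loss.backward() (torch.autograd.backward) routes
            # gradients through DDP's reducer hooks, whereas
            # torch.autograd.grad bypasses them entirely.
            loss.backward()

    # The final backward runs outside no_sync(): DDP all-reduces the
    # accumulated .grad buffers while the backward pass is still executing,
    # overlapping communication with computation.
    loss = loss_fn(ddp_module(batches[-1]))
    loss.backward()
```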

sangkeun00 commented 1 year ago

Done