In distributed training, PyTorch (via DistributedDataParallel) minimizes throughput degradation by overlapping gradient communication with backward computation.
However, Betty currently performs the computation first and then manually performs gradient synchronization, so it does not benefit from this communication-computation overlap.
This is mainly because hypergradient calculation often requires second-order gradient computation as well as multiple forward-backward passes.
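A minimal sketch of why hypergradients need second-order terms and several passes, using a hypothetical one-step-unrolled toy problem (not Betty's implementation) where the hyperparameter lam is an L2 weight. Both gradient calls use torch.autograd.grad, so nothing is accumulated into .grad, and the first call passes create_graph=True so that the second call can differentiate through the inner gradient.

```python
import torch

# Hypothetical toy bilevel problem, for illustration only:
#   inner step: w_new = w - lr * dL_train(w, lam)/dw
#   outer loss: L_val(w_new)
# d L_val / d lam has to differentiate through dL_train/dw (second order).
torch.manual_seed(0)
x_train, y_train = torch.randn(8, 3), torch.randn(8)
x_val, y_val = torch.randn(8, 3), torch.randn(8)

w = torch.ones(3, requires_grad=True)        # inner (model) parameters
lam = torch.tensor(0.1, requires_grad=True)  # outer (hyper) parameter
lr = 0.5


def train_loss(w, lam):
    return ((x_train @ w - y_train) ** 2).mean() + lam * (w ** 2).sum()


def val_loss(w):
    return ((x_val @ w - y_val) ** 2).mean()


# Pass 1: inner gradient, keeping its graph so it can be differentiated later.
(g_w,) = torch.autograd.grad(train_loss(w, lam), w, create_graph=True)
w_new = w - lr * g_w  # depends on lam through g_w

# Pass 2: outer loss, differentiated back through g_w to get the hypergradient.
(hypergrad,) = torch.autograd.grad(val_loss(w_new), lam)
print(hypergrad)
```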
To improve distributed training performance, we can:
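1) make use of PyTorch's native communication-computation overlap by replacing torch.autograd.grad with torch.autograd.backward (DistributedDataParallel only works when gradients are accumulated into the parameters' .grad attributes, which torch.autograd.grad does not do)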
2) keep most of the hypergradient computation local, and perform gradient synchronization only once at the end.
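A small single-process illustration of the difference that item 1) relies on: torch.autograd.grad returns gradients as new tensors and leaves each parameter's .grad untouched, while torch.autograd.backward accumulates into .grad, which is the path DistributedDataParallel hooks into for its overlapped allreduce. The toy Linear model is an assumption for illustration; no process group is needed here.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
params = list(model.parameters())
x = torch.randn(4, 3)

# torch.autograd.grad: gradients come back as new tensors; .grad is not touched.
loss = model(x).pow(2).mean()
grads = torch.autograd.grad(loss, params)
print([p.grad for p in params])  # [None, None]

# torch.autograd.backward: gradients are accumulated into each parameter's .grad.
# `inputs=` restricts accumulation to the listed tensors.
loss = model(x).pow(2).mean()
torch.autograd.backward(loss, inputs=params)
print(all(torch.allclose(g, p.grad) for g, p in zip(grads, params)))  # True
```

torch.autograd.backward also accepts create_graph=True, so second-order terms can still be built on top of the accumulated .grad values when a hypergradient needs them.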
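A minimal sketch of item 2), assuming the model is not wrapped in DDP during the intermediate passes (with DDP, its no_sync() context manager serves a similar purpose). Every pass only touches local .grad, and a hypothetical helper, sync_grads_once, averages all gradients with one flattened all_reduce at the end. It returns without communicating when no process group is initialized, so the snippet also runs in a single process; after init_process_group (e.g. under torchrun) it performs the single collective.

```python
import torch
import torch.distributed as dist


def sync_grads_once(params):
    """Hypothetical helper: average .grad across ranks with ONE all_reduce.

    Returns without communicating if no process group is initialized.
    """
    params = [p for p in params if p.grad is not None]
    if not params or not (dist.is_available() and dist.is_initialized()):
        return
    # Pack every gradient into one contiguous buffer so only one collective runs.
    flat = torch.cat([p.grad.reshape(-1) for p in params])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat /= dist.get_world_size()
    # Unpack the averaged values back into each parameter's .grad.
    offset = 0
    for p in params:
        n = p.grad.numel()
        p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
        offset += n


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)

# Stand-in for the several forward/backward passes of a hypergradient
# computation: each backward accumulates into .grad locally, no communication.
for _ in range(3):
    loss = model(torch.randn(4, 3)).pow(2).mean()
    loss.backward()

# Synchronize once, after all local computation has finished.
sync_grads_once(model.parameters())
```

Packing everything into one buffer costs an extra copy of the gradients but replaces many small collectives with a single one; DDP's bucketed, overlapped allreduce is the alternative that item 1) targets.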