tkbadamdorj opened 2 years ago
Setting reduction = 'sum' has no effect (the mean is still used) because of this line:
if self._reduction: merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
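self._reduction is a string, and any non-empty string is truthy in Python. The condition is therefore always True, so the mean branch runs even when reduction = 'sum'.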
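Below is a minimal, framework-free sketch of the problem and one possible fix: compare the string explicitly instead of relying on its truthiness. The names merge_buggy and merge_fixed are hypothetical. Gradients are plain lists of floats rather than tensors. The elif 'sum' branch is an assumption about what follows the quoted line, so this is not the library's actual code.

```python
# Hypothetical, simplified stand-in for the gradient-merge step:
# each task gradient is a list of floats instead of a torch tensor.

def merge_buggy(pc_grad, reduction):
    # Mirrors the reported condition: any non-empty string is truthy,
    # so this branch is taken for 'sum' as well as 'mean'.
    if reduction:
        return [sum(col) / len(col) for col in zip(*pc_grad)]
    elif reduction == 'sum':  # assumed branch; unreachable for non-empty strings
        return [sum(col) for col in zip(*pc_grad)]
    raise ValueError(f"invalid reduction method: {reduction!r}")


def merge_fixed(pc_grad, reduction):
    # Compare against each accepted value explicitly.
    if reduction == 'mean':
        return [sum(col) / len(col) for col in zip(*pc_grad)]
    if reduction == 'sum':
        return [sum(col) for col in zip(*pc_grad)]
    raise ValueError(f"invalid reduction method: {reduction!r}")


grads = [[1.0, 2.0], [3.0, 4.0]]
print(merge_buggy(grads, 'sum'))  # [2.0, 3.0]  (the mean, not the sum)
print(merge_fixed(grads, 'sum'))  # [4.0, 6.0]
```

Raising on an unknown value also surfaces typos such as 'Sum' immediately, instead of silently falling into one of the branches.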
See the forked version here with the issue corrected:
https://github.com/anzeyimana/Pytorch-PCGrad-GradVac-AMP-GradAccum