WeiChengTseng / Pytorch-PCGrad

Pytorch reimplementation for "Gradient Surgery for Multi-Task Learning"
BSD 3-Clause "New" or "Revised" License
302 stars 42 forks

reduction is always 'mean' #14

Open tkbadamdorj opened 2 years ago

tkbadamdorj commented 2 years ago

Setting reduction='sum' has no effect because of this check:

if self._reduction:
    merged_grad[shared] = torch.stack([g[shared]
                                   for g in pc_grad]).mean(dim=0)

Because reduction is a non-empty string, self._reduction is always truthy, so the mean branch always runs and 'sum' is never applied.
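One possible fix is to compare the reduction name explicitly instead of testing its truthiness. The sketch below is a plain-Python stand-in, not the repository's code: `reduce_shared` and its `per_task_grads` argument are hypothetical names, and lists of floats take the place of the stacked `g[shared]` tensors. In the actual optimizer the same branching would presumably select between `.mean(dim=0)` and `.sum(dim=0)` on the stacked tensor.

```python
def reduce_shared(per_task_grads, reduction="mean"):
    """Combine per-task gradient entries using an explicit reduction name.

    per_task_grads: list of equal-length lists of floats (a hypothetical
    stand-in for the stacked `g[shared]` tensors quoted in the issue).
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"invalid reduction: {reduction!r}")
    # Element-wise sum across tasks.
    summed = [sum(column) for column in zip(*per_task_grads)]
    if reduction == "sum":
        return summed
    # 'mean' divides the element-wise sum by the number of tasks.
    return [value / len(per_task_grads) for value in summed]


# Any non-empty string is truthy, which is why `if self._reduction:`
# takes the first branch for both 'mean' and 'sum'.
assert bool("sum") and bool("mean")

grads = [[1.0, 2.0], [3.0, 4.0]]
print(reduce_shared(grads, "mean"))  # [2.0, 3.0]
print(reduce_shared(grads, "sum"))   # [4.0, 6.0]
```

Rejecting unknown names with a `ValueError` is a design choice in this sketch; it keeps a typo such as 'summ' from silently falling back to the mean.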

anzeyimana commented 2 years ago

See this fork, where the issue is fixed:

https://github.com/anzeyimana/Pytorch-PCGrad-GradVac-AMP-GradAccum