A PyTorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models
628 stars, 49 forks
Implicit in-place operation '*=' causes an error when computing gradients in the backward pass in PyTorch #6
Closed
VRCMF closed 2 years ago
The implicit in-place '*=' operation in the code raises an error, so the gradient cannot be computed in the backward pass.
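Solution: replace the in-place multiplication with an out-of-place one, which leaves the original tensor untouched:
density_1_proxy = density_1_proxy * equals_one_mask[..., None]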
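For anyone hitting the same error, here is a minimal sketch of why '*=' can break the backward pass while the out-of-place form does not. It is not the repository's code: the function names, shapes, and the assumption that the masked tensor is the direct output of a softmax are hypothetical and chosen only for illustration. Softmax keeps its output for the backward pass, so overwriting that output in place invalidates what autograd saved.

```python
import torch

def masked_gates_inplace(logits, mask):
    # Hypothetical stand-in for the failing pattern.
    gates = logits.softmax(dim=-1)   # autograd saves this output for backward
    gates *= mask[..., None]         # in-place: overwrites the saved output
    return gates

def masked_gates_out_of_place(logits, mask):
    # Same computation, but the product is written to a new tensor.
    gates = logits.softmax(dim=-1)
    gates = gates * mask[..., None]  # softmax output stays intact
    return gates

logits = torch.randn(2, 4, 3, requires_grad=True)
mask = torch.tensor([[1., 0., 1., 1.],
                     [0., 1., 1., 0.]])

# The in-place version fails when backward() runs, not at the '*=' line.
try:
    masked_gates_inplace(logits, mask).sum().backward()
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place version failed:", str(err)[:80])

# The out-of-place version backpropagates normally.
masked_gates_out_of_place(logits, mask).sum().backward()
grad_ok = logits.grad is not None
print("out-of-place version produced a gradient:", grad_ok)
```

Whether '*=' fails depends on whether the operation that produced the tensor needs its output again during backward, which is why the same line can work in one place and fail in another. The out-of-place form costs one extra allocation but is safe in both cases.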