Open MukundVarmaT opened 2 years ago

Hi @lucidrains

I am a little confused about how the parameters experts.w1 and experts.w2 are updated. The top1 operation is non-differentiable, so I expected the gradients of these two parameters to be None. To confirm, I even ran the following:

which gave the following output:

It would be really helpful if you could clarify my understanding. Thanks!

Did you try it on CPU? It can work. The top1 operation is non-differentiable, but the balance loss is based on the logits of the gating distribution and the count of tokens per expert, so the gradient of the gating weight should actually not be None.
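For what it's worth, here is a minimal, hypothetical top-1 MoE sketch (not this repo's actual code) illustrating why gradients still reach both the gate and the expert parameters: the selected expert's output is scaled by the soft gate probability, and the Switch-Transformer-style balance loss is also computed from the soft gating distribution, so only the routing indices themselves are non-differentiable.

```python
import torch
import torch.nn as nn

# Hypothetical minimal top-1 MoE (illustrative only). Gradients flow because:
#   1. each expert's output is scaled by its *soft* gate probability, putting
#      both the gate and the selected expert on a differentiable path, and
#   2. the auxiliary balance loss also uses the soft gate distribution.
torch.manual_seed(0)

num_experts, dim, num_tokens = 4, 8, 32
gate = nn.Linear(dim, num_experts, bias=False)
experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

x = torch.randn(num_tokens, dim)
probs = gate(x).softmax(dim=-1)          # differentiable gating distribution
top1_prob, top1_idx = probs.max(dim=-1)  # hard routing (indices are non-diff)

out = torch.zeros_like(x)
for e in range(num_experts):
    mask = top1_idx == e
    if mask.any():
        # expert output weighted by the differentiable gate probability
        out[mask] = top1_prob[mask].unsqueeze(-1) * experts[e](x[mask])

# Balance loss: tokens-per-expert fraction (treated as a constant) times the
# mean gate probability per expert (differentiable), summed over experts.
frac = torch.bincount(top1_idx, minlength=num_experts).float() / num_tokens
balance_loss = num_experts * (frac * probs.mean(dim=0)).sum()

loss = out.pow(2).mean() + balance_loss
loss.backward()

routed = set(top1_idx.tolist())
print(gate.weight.grad is not None)                             # True
print(all(experts[e].weight.grad is not None for e in routed))  # True
```

Note that an expert which received zero tokens in a batch would still have a None gradient, which may be what you observed; only the experts that tokens were routed to get updated on that step.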