lucidrains / mixture-of-experts

A Pytorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models

Regarding experts.w1 and experts.w2 gradients #7

Open MukundVarmaT opened 2 years ago

MukundVarmaT commented 2 years ago

Hi @lucidrains

I am a little confused about how the parameters experts.w1 and experts.w2 are updated. The top1 operation is non-differentiable, so I would expect the gradients of these two parameters to be None. To confirm, I even ran the following:

import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >= 1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing loss
).cuda()
inputs = torch.randn(4, 1024, 512).cuda()
out, aux_loss = moe(inputs)     # out: (4, 1024, 512), aux_loss: (1,)
aux_loss.backward()             # backpropagate only the auxiliary balancing loss
for name, param in moe.named_parameters():
    if param.grad is None:
        print(name)

which gave the following output:

experts.w1
experts.w2

It would be really helpful if you could clarify my understanding. Thanks!

rattlesnakey commented 1 year ago

Have you tried it on CPU? It works for me. The top1 operation is non-differentiable, but the balancing loss is computed from the logits of the gating distribution and the number of tokens routed to each expert, so the gradients of the weights should not actually be None.
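
For reference, a minimal sketch of that check (assuming the same MoE module from this repo, run on CPU with a dummy loss over the outputs): when the backpropagated loss depends on the expert outputs rather than on aux_loss alone, experts.w1 and experts.w2 should receive gradients, since each token's expert output is scaled by its differentiable gate value even though the top1 index selection itself carries no gradient. Backpropagating only aux_loss touches just the gating weights, which matches the output reported above.

import torch
from mixture_of_experts import MoE

moe = MoE(dim = 512, num_experts = 16)  # other arguments left at their defaults

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)

# dummy loss that depends on the expert outputs, plus the balancing loss
loss = out.sum() + aux_loss
loss.backward()

for name, param in moe.named_parameters():
    print(name, 'grad is None' if param.grad is None else 'grad populated')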