google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

updates returns zeros #1101

Closed GeophyAI closed 1 month ago

GeophyAI commented 1 month ago

I'm using optax for training. When I call updates, opt_state = opt.update(gradient, opt_state) and inspect updates, I find that it contains only zeros, even though the gradients do have non-zero values. In which situations can this happen?
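For context, a minimal sketch of that call pattern with a plain optimizer (toy parameters and optax.adam assumed purely for illustration); with non-zero gradients the updates here come out non-zero:

import jax.numpy as jnp
import optax

# Toy parameter tree and gradients, assumed for illustration only.
params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
gradient = {"w": jnp.full((3,), 0.5), "b": jnp.full((3,), -1.0)}

opt = optax.adam(1e-3)
opt_state = opt.init(params)

# With a plain optimizer and non-zero gradients, updates are non-zero.
updates, opt_state = opt.update(gradient, opt_state)
print(updates)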

vroulet commented 1 month ago

Hello @GeophyAI,

It depends on the optimizer. We cannot help without a concrete example of what you are trying to do.

GeophyAI commented 1 month ago

I'm using optax.masked and optax.chain to assign different learning rates to different parameter groups, something like the following implementation:

paras_counts = len(pars_need_by_eq)

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

optimizers = []
for i, para in enumerate(pars_need_by_eq):
    # Set the learning rate for each parameter
    _lr = 0. if para not in pars_need_invert else lr[para]  # *scale_decay**idx_freq
    lr_schedule = optax.exponential_decay(_lr * scale_decay**idx_freq, 1, epoch_decay)
    # opt = optax.adam(lr_schedule, eps=1e-22)
    opt = optax.inject_hyperparams(optax.adam)(
        learning_rate=lambda count: lr_schedule(count), eps=1e-22)
    self.logger.print(f"Learning rate for {para}: {lr_schedule(0)}")
    mask = create_mask_fn(i, paras_counts)
    optimizers.append(optax.masked(opt, mask))

return optax.chain(*optimizers)

When my model has only one parameter group, it works fine; when there is more than one group, the updates are always zeros.
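A self-contained sketch of this multi-group setup, with a toy two-leaf parameter tuple and placeholder schedule hyperparameters assumed (not the values from the snippet above), so the resulting updates can be inspected directly:

import jax.numpy as jnp
import optax

# Assumed toy setup: two parameter groups as leaves of a tuple.
params = (jnp.ones((3,)), jnp.ones((3,)))
grads = (jnp.full((3,), 0.1), jnp.full((3,), 0.2))

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

optimizers = []
for i in range(len(params)):
    # Placeholder schedule hyperparameters, assumed for illustration.
    lr_schedule = optax.exponential_decay(1e-3, 1, 0.99)
    opt = optax.inject_hyperparams(optax.adam)(
        learning_rate=lambda count: lr_schedule(count), eps=1e-22)
    optimizers.append(optax.masked(opt, create_mask_fn(i, len(params))))

tx = optax.chain(*optimizers)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
print(updates)  # reported to come out as all zeros when there is more than one group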

GeophyAI commented 1 month ago

I found that when I replace the line opt = optax.inject_hyperparams(optax.adam)(learning_rate=lambda count: lr_schedule(count), eps=1e-22) with opt = optax.adam(lr_schedule, eps=1e-22), it works for both single and multiple parameter groups.
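For comparison, the same sketch with the workaround applied, i.e. passing the schedule directly to optax.adam instead of going through optax.inject_hyperparams (same assumed toy setup and placeholder hyperparameters as above):

import jax.numpy as jnp
import optax

# Same assumed toy setup: two parameter groups as leaves of a tuple.
params = (jnp.ones((3,)), jnp.ones((3,)))
grads = (jnp.full((3,), 0.1), jnp.full((3,), 0.2))

optimizers = []
for i in range(len(params)):
    lr_schedule = optax.exponential_decay(1e-3, 1, 0.99)  # placeholder hyperparameters
    opt = optax.adam(lr_schedule, eps=1e-22)  # schedule passed directly, no inject_hyperparams
    mask = tuple(j == i for j in range(len(params)))
    optimizers.append(optax.masked(opt, mask))

tx = optax.chain(*optimizers)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
print(updates)  # reported to give non-zero updates for both groups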