Hello @GeophyAI,
It depends on the optimizer. We cannot help without a concrete example of what you are trying to do.
I'm using optax.masked and optax.chain to assign different learning rates to different parameter groups, something like the following implementation:
paras_counts = len(pars_need_by_eq)

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

optimizers = []
for i, para in enumerate(pars_need_by_eq):
    # Set the learning rate for each parameter
    _lr = 0. if para not in pars_need_invert else lr[para]  # *scale_decay**idx_freq
    lr_schedule = optax.exponential_decay(_lr*scale_decay**idx_freq, 1, epoch_decay)
    # opt = optax.adam(lr_schedule, eps=1e-22)
    opt = optax.inject_hyperparams(optax.adam)(learning_rate=lambda count: lr_schedule(count), eps=1e-22)
    self.logger.print(f"Learning rate for {para}: {lr_schedule(0)}")
    mask = create_mask_fn(i, paras_counts)
    optimizers.append(optax.masked(opt, mask))
return optax.chain(*optimizers)
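For reference, here is a minimal self-contained sketch of the same masked/chain pattern (the parameter tuple, group count, and learning rates are hypothetical), showing the boolean-mask convention that optax.masked expects:

```python
import jax.numpy as jnp
import optax

# Hypothetical two-group parameter pytree (a tuple, as in the snippet above).
params = (jnp.ones(3), jnp.ones(3))

def create_mask(index, num_params):
    # optax.masked expects a pytree of booleans with the same structure as params.
    return tuple(i == index for i in range(num_params))

# One masked optimizer per group, each with its own learning rate.
opt = optax.chain(
    optax.masked(optax.adam(1e-2), create_mask(0, len(params))),
    optax.masked(optax.adam(1e-4), create_mask(1, len(params))),
)

opt_state = opt.init(params)
grads = (jnp.ones(3), jnp.ones(3))
updates, opt_state = opt.update(grads, opt_state)
# Each group is handled by its own Adam instance; neither update is all zeros.
```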
When my model has only one parameter group, it works fine; when there is more than one group, the updates are always zeros.
I found that when I replace the line
opt = optax.inject_hyperparams(optax.adam)(learning_rate=lambda count: lr_schedule(count), eps=1e-22)
with
opt = optax.adam(lr_schedule, eps=1e-22)
it works for both single and multiple parameter groups.
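One possibly relevant detail (not confirmed in this thread): a lambda created inside a for loop closes over the name lr_schedule rather than its current value, so when it is evaluated later every group would see the last schedule created by the loop. Passing the schedule object directly, or binding it as a default argument, avoids that. A minimal sketch of both variants, with hypothetical per-group schedules:

```python
import optax

# Hypothetical per-group schedules; the second group's base rate is 0.0,
# mirroring the `_lr = 0.` branch in the snippet above.
schedules = [optax.exponential_decay(1e-2, 1, 0.99),
             optax.exponential_decay(0.0, 1, 0.99)]

plain_opts, injected_opts = [], []
for lr_schedule in schedules:
    # The variant reported to work: the schedule object is captured here.
    plain_opts.append(optax.adam(lr_schedule, eps=1e-22))

    # inject_hyperparams variant with the schedule bound as a default argument,
    # so each lambda keeps its own schedule rather than looking up the name
    # lr_schedule (and finding the loop's final value) when it is later called.
    injected_opts.append(
        optax.inject_hyperparams(optax.adam)(
            learning_rate=lambda count, s=lr_schedule: s(count), eps=1e-22
        )
    )
```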
I'm using optax for training. After calling
updates, opt_state = opt.update(gradient, opt_state)
and checking the updates, I found that they contain only zeros, but the gradients do have values. In which situations can this happen?
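For what it's worth, one situation that reproduces this exactly is a learning rate of zero (as in the `_lr = 0.` branch above): Adam's final step scales every entry by the learning rate, so the updates are zero even though the gradients are not. Whether that is what happens in the setup above is not confirmed here. A minimal sketch:

```python
import jax.numpy as jnp
import optax

params = jnp.ones(3)
opt = optax.adam(0.0, eps=1e-22)   # zero learning rate
opt_state = opt.init(params)

grads = jnp.full(3, 2.0)           # non-zero gradients
updates, opt_state = opt.update(grads, opt_state)
print(updates)                     # zeros: the final scale-by-learning-rate multiplies by 0
```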