google-deepmind / sonnet

TensorFlow-based neural network library
https://sonnet.dev/
Apache License 2.0

Decaying learning rate #247

Closed isabellahuang closed 2 years ago

isabellahuang commented 2 years ago

In TF 1, one could use a decaying learning rate like the following:

lr = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                global_step=global_step,
                                decay_steps=int(5e6),
                                decay_rate=0.1) + 1e-6
optimizer = tf.train.AdamOptimizer(learning_rate=lr)

Is this possible in Sonnet 2?

tomhennigan commented 2 years ago

Hi @isabellahuang, Sonnet 2 doesn't ship with any pre-defined learning rate schedules, but you can pass a tf.Variable as the learning rate and then assign the result of your schedule to that variable on each step:

def exponential_decay(initial_learning_rate, decay_steps, decay_rate):
  # Maps an (integer) global step to a decayed learning rate.
  return lambda global_step: initial_learning_rate * tf.math.pow(
      decay_rate, tf.cast(global_step, tf.float32) / decay_steps)

lr_schedule = exponential_decay(initial_learning_rate, decay_steps, decay_rate)
global_step = tf.Variable(0)
# The optimizer reads its learning rate from this variable on every apply().
learning_rate = tf.Variable(lr_schedule(global_step))
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

for batch in dataset:
  # .. snip ..
  learning_rate.assign(lr_schedule(global_step))  # Refresh the rate in place.
  optimizer.apply(updates, parameters)
  global_step.assign_add(1)

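For completeness, here is a minimal end-to-end sketch of the same pattern with the snipped part filled in. The snt.Linear regression model, the loss, and the random data are just stand-ins I'm assuming for illustration; only the learning-rate plumbing is the point:

import sonnet as snt
import tensorflow as tf

# Toy data and model, purely illustrative.
xs = tf.random.normal([256, 8])
ys = tf.random.normal([256, 1])
dataset = tf.data.Dataset.from_tensor_slices((xs, ys)).batch(32)

model = snt.Linear(1)
model(xs[:1])  # Call once so the parameters exist before training.

def exponential_decay(initial_learning_rate, decay_steps, decay_rate):
  return lambda step: initial_learning_rate * tf.math.pow(
      decay_rate, tf.cast(step, tf.float32) / decay_steps)

lr_schedule = exponential_decay(1e-3, decay_steps=1000, decay_rate=0.1)
global_step = tf.Variable(0)
learning_rate = tf.Variable(lr_schedule(global_step))
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

for x, y in dataset:
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
  parameters = model.trainable_variables
  updates = tape.gradient(loss, parameters)
  learning_rate.assign(lr_schedule(global_step))  # Decay before each apply.
  optimizer.apply(updates, parameters)
  global_step.assign_add(1)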
There is another example in our tests:

https://github.com/deepmind/sonnet/blob/d1cd37117bcb98223b3e4b930717d418abb76484/sonnet/src/optimizers/adam_test.py#L120-L125
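
As an aside, if you'd rather not hand-roll the schedule, TF2's tf.keras.optimizers.schedules.ExponentialDecay (the replacement for tf.train.exponential_decay) is callable with a step, so it can drive the same variable. A sketch using the decay settings from your snippet, with a placeholder initial rate since yours comes from FLAGS:

import sonnet as snt
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=int(5e6), decay_rate=0.1)

global_step = tf.Variable(0)
learning_rate = tf.Variable(schedule(global_step) + 1e-6)  # Same 1e-6 floor as the TF1 code.
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

# Inside the training loop:
learning_rate.assign(schedule(global_step) + 1e-6)
global_step.assign_add(1)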