Hi @isabellahuang, Sonnet 2 doesn't ship with any pre-defined learning rate schedules, but you can pass a `tf.Variable` as the learning rate and then assign the result of your schedule to that variable on each step:
```python
import sonnet as snt
import tensorflow as tf

def exponential_decay(learning_rate, decay_steps, decay_rate):
  # learning_rate * decay_rate ** (step / decay_steps); the cast keeps it float32.
  return lambda global_step: learning_rate * tf.math.pow(
      decay_rate, tf.cast(global_step, tf.float32) / decay_steps)

# Example hyperparameter values (placeholders only).
lr_schedule = exponential_decay(learning_rate=0.1, decay_steps=1000, decay_rate=0.96)
global_step = tf.Variable(0)
learning_rate = tf.Variable(lr_schedule(global_step))
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

for batch in dataset:
  # .. snip ..
  learning_rate.assign(lr_schedule(global_step))
  optimizer.apply(updates, parameters)
  global_step.assign_add(1)
```
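If you'd rather not hand-roll the decay function, the same assign-to-a-variable pattern works with one of TensorFlow's built-in Keras schedules (a minimal sketch, assuming `tf.keras.optimizers.schedules.ExponentialDecay`; this is a Keras utility, not part of Sonnet):

```python
import sonnet as snt
import tensorflow as tf

# ExponentialDecay is a callable: schedule(step) returns the decayed rate.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)

global_step = tf.Variable(0)
learning_rate = tf.Variable(lr_schedule(global_step))
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

for batch in dataset:
  # .. snip: compute `updates` (gradients) for `parameters` ..
  learning_rate.assign(lr_schedule(global_step))
  optimizer.apply(updates, parameters)
  global_step.assign_add(1)
```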
There is another example in our tests:
For reference, the original question was:

> In TF 1, one could use a decaying learning rate like the following:
>
> Is this possible in Sonnet 2?
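A typical TF 1 setup of the kind referenced looked roughly like this (a sketch assuming `tf.train.exponential_decay` and `tf.train.AdamOptimizer`; the snippet from the original question is not reproduced above):

```python
import tensorflow as tf  # TF 1.x

global_step = tf.train.get_or_create_global_step()

# Decay the learning rate by a factor of 0.96 every 1000 steps
# (hyperparameter values here are illustrative only).
learning_rate = tf.train.exponential_decay(
    learning_rate=0.1,
    global_step=global_step,
    decay_steps=1000,
    decay_rate=0.96)

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
# `loss` is assumed to be defined by the graph being trained; passing
# `global_step` makes `minimize` increment it on every update.
train_op = optimizer.minimize(loss, global_step=global_step)
```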