uu-sml / sml-book-page

Page for the SML book

A plot of a non-constant learning rate for stochastic gradient descent #28

Closed: dmitrijsk closed this issue 3 years ago

dmitrijsk commented 3 years ago

A non-constant learning rate for stochastic gradient descent is clearly explained in Section 5.3, “Optimization with large datasets”, and then applied to neural networks in Example 6.3. However, it took me some thinking to realize that the learning rate $\gamma$ gets close to $\gamma_{min}$ only towards the 10,000th iteration ($\approx$ 16.67 epochs), and that after $\tau = 2000$ iterations ($\approx$ 3.33 epochs) it is still far from the minimum ($\gamma \approx 0.0012$, compared with $\gamma_{min} = 0.0001$). To visualize the shape of the learning-rate schedule, I made the plot below. I am sharing it here, together with the Python code, in case you find it useful for the book. I hope the computations are correct.
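For reference, here is the arithmetic behind these numbers. It uses the exponentially decaying schedule as implemented in the plotting code (with $\gamma_{min} = 0.0001$, $\gamma_{max} = 0.003$, $\tau = 2000$) and one epoch of $60000 / 100 = 600$ mini-batch iterations. This is a worked check under those values, not a restatement of the book's derivation.

$$
\gamma^{(t)} = \gamma_{min} + (\gamma_{max} - \gamma_{min})\, e^{-t/\tau}
$$

$$
\begin{aligned}
\gamma^{(2000)} &= 0.0001 + 0.0029\, e^{-1} \approx 0.0001 + 0.00107 \approx 0.0012, & \tfrac{2000}{600} &\approx 3.33 \text{ epochs}, \\
\gamma^{(10000)} &= 0.0001 + 0.0029\, e^{-5} \approx 0.0001 + 0.00002 \approx 0.00012, & \tfrac{10000}{600} &\approx 16.67 \text{ epochs}.
\end{aligned}
$$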

[Figure: non-constant learning rate $\gamma$ for neural nets, plotted against iteration number]

```python
import numpy as np
import matplotlib.pyplot as plt

gamma_min = 0.0001
gamma_max = 0.003
tau = 2000            # decay parameter.
n_data = 60000        # number of data points.
n_minibatch = 100     # size of the mini-batch.
n_iter = 10000        # iterations in training.

epoch_size = n_data // n_minibatch  # iterations per epoch (600).

# Exponentially decaying learning rate for each iteration t = 1, ..., n_iter (vectorized).
iter_ = np.arange(1, n_iter + 1)
gammas = gamma_min + (gamma_max - gamma_min) * np.exp(-iter_ / tau)

# Plot the learning-rate schedule with dashed horizontal lines at gamma_max and gamma_min.
fig, ax = plt.subplots()
ax.plot(iter_, gammas, linewidth=3.0)
ax.axhline(gamma_max, color="black", linestyle="dashed")
ax.axhline(gamma_min, color="black", linestyle="dashed")
# Raw strings keep "\gamma" from being parsed as an (invalid) escape sequence.
ax.set(ylabel=r"Learning rate $\gamma$",
       title=r"Non-constant learning rate $\gamma$ as a function of iteration number")

# x-axis tick labels showing both epochs and iterations.
x_epochs = np.array([1, 5, 10, 15])
x_iters = x_epochs * epoch_size
x_labels = [f"{e} epochs\n{it} iterations" for e, it in zip(x_epochs, x_iters)]
ax.set_xticks(x_iters)
ax.set_xticklabels(x_labels)

# Annotations for the dashed lines.
ax.annotate(r"$\gamma_{min}$=" + str(gamma_min), (0, gamma_min * 2))
ax.annotate(r"$\gamma_{max}$=" + str(gamma_max), (8000, gamma_max - gamma_min * 2))

plt.show()
```
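To make the "close to $\gamma_{min}$ only towards the 10,000th iteration" observation quantitative, the schedule can also be inverted: solving $(\gamma_{max} - \gamma_{min})\, e^{-t/\tau} = r\,\gamma_{min}$ for $t$ gives the iteration at which $\gamma$ is within a relative tolerance $r$ of $\gamma_{min}$. The sketch below assumes the same parameter values as the plot; the helper names `learning_rate`, `iters_to_epochs` and `iters_until_close` are hypothetical and not from the book.

```python
import math


def learning_rate(t, gamma_min=0.0001, gamma_max=0.003, tau=2000):
    """Exponentially decaying learning rate gamma^(t) (same schedule as the plot above)."""
    return gamma_min + (gamma_max - gamma_min) * math.exp(-t / tau)


def iters_to_epochs(n_iter, n_data=60000, n_minibatch=100):
    """Convert an iteration count to epochs (one epoch = n_data / n_minibatch iterations)."""
    return n_iter * n_minibatch / n_data


def iters_until_close(rel_tol, gamma_min=0.0001, gamma_max=0.003, tau=2000):
    """Hypothetical helper: iteration t at which gamma^(t) = (1 + rel_tol) * gamma_min.

    Solves (gamma_max - gamma_min) * exp(-t / tau) = rel_tol * gamma_min for t.
    """
    return tau * math.log((gamma_max - gamma_min) / (rel_tol * gamma_min))


# Learning rate and epoch count at the two iterations discussed in the issue.
for t in (2000, 10000):
    print(f"t = {t:5d} ({iters_to_epochs(t):.2f} epochs): gamma = {learning_rate(t):.6f}")

# Iteration at which gamma is within 20% of gamma_min.
t_close = iters_until_close(0.2)
print(f"gamma is within 20% of gamma_min after t = {t_close:.0f} iterations "
      f"({iters_to_epochs(t_close):.1f} epochs)")
```

With these values the 20% threshold is reached at $t = 2000 \ln 145$, i.e. just under 10,000 iterations (about 16.6 epochs), which matches the shape of the plot.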
nikwa commented 3 years ago

Thanks for the comment. I think it is a great idea and adds clarification. A similar figure will be added to Example 6.3 in the next version.