titu1994 / keras-one-cycle

Implementation of One-Cycle Learning rate policy (adapted from Fast.ai lib)
MIT License

One Cycle Learning Rate Policy for Keras

Implementation of One-Cycle Learning rate policy from the papers by Leslie N. Smith.

Contains two Keras callbacks, LRFinder and OneCycleLR, which are ported from the PyTorch Fast.ai library.

What is One Cycle Learning Rate

During the first half of the cycle, the learning rate is gradually increased and, optionally, the momentum is gradually decreased. During the second half, the learning rate is gradually decreased and, optionally, the momentum is increased again.

Finally, over a set percentage of iterations at the end of the cycle, the learning rate is sharply reduced every epoch.
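The shape described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the code of the OneCycleLR callback: the function name `one_cycle`, the `lr_div` factor (starting learning rate = `max_lr / lr_div`), the linear interpolation, and the tail decaying toward one hundredth of the starting rate are all assumptions made for the example. It also updates per step rather than per epoch, for simplicity.

```python
def one_cycle(step, total_steps, max_lr, end_percentage=0.1,
              lr_div=10.0, max_momentum=0.95, min_momentum=0.85):
    """Return (lr, momentum) for a 0-based step of a one-cycle schedule.

    Illustrative sketch only; lr_div and the tail decay factor are assumed.
    """
    base_lr = max_lr / lr_div
    # Length of each of the two symmetric halves; what is left is the tail.
    half = max(1, int(total_steps * (1.0 - end_percentage) / 2))
    tail = max(1, total_steps - 2 * half)

    if step < half:
        # First half: learning rate rises, momentum falls.
        frac = step / half
        lr = base_lr + frac * (max_lr - base_lr)
        momentum = max_momentum - frac * (max_momentum - min_momentum)
    elif step < 2 * half:
        # Second half: learning rate falls, momentum rises.
        frac = (step - half) / half
        lr = max_lr - frac * (max_lr - base_lr)
        momentum = min_momentum + frac * (max_momentum - min_momentum)
    else:
        # Tail: learning rate drops well below its starting value.
        frac = (step - 2 * half) / tail
        lr = base_lr * (1.0 - 0.99 * frac)
        momentum = max_momentum
    return lr, momentum


# Example: inspect a 100-step cycle with a peak learning rate of 1.0.
schedule = [one_cycle(s, 100, 1.0) for s in range(100)]
```

Plotting the first and second element of each tuple in `schedule` against the step index should give curves with the same general shape as the two schedules described in this section.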

The Learning rate schedule is visualized as :

The Optional Momentum schedule is visualized as :

Usage

Finding a good learning rate

Use LRFinder to obtain a loss plot, and visually inspect it to determine the maximum learning rate. Provided below is an example, used for the MiniMobileNetV2 model.

An example script has been provided in find_lr_schedule.py inside the models/mobilenet/ directory.

Essentially,

from clr import LRFinder

lr_callback = LRFinder(num_samples, batch_size,
                       minimum_lr, maximum_lr,
                       # validation_data=(X_val, Y_val),
                       lr_scale='exp', save_dir='path/to/save/directory')

# Ensure that number of epochs = 1 when calling fit()
model.fit(X, Y, epochs=1, batch_size=batch_size, callbacks=[lr_callback])

The above callback does a few things.

Note : When using this, be careful about setting the learning rate, momentum and weight decay schedule. The loss plots will be more erratic due to the sampling of the validation set.
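To make the idea of an exponential learning rate sweep concrete, the sketch below generates one learning rate per batch, growing geometrically from `minimum_lr` to `maximum_lr` over a single epoch. It is an illustration under assumptions, not the LRFinder source: the helper name `exp_lr_sweep` is hypothetical, and the assumption is that `lr_scale='exp'` means multiplying the learning rate by a constant factor after every batch.

```python
import math


def exp_lr_sweep(num_samples, batch_size, minimum_lr, maximum_lr):
    """Return one learning rate per batch, spaced geometrically.

    Hypothetical helper for illustration; not part of the clr module.
    """
    num_batches = math.ceil(num_samples / batch_size)
    if num_batches == 1:
        return [minimum_lr]
    # Constant multiplier that takes minimum_lr to maximum_lr
    # in (num_batches - 1) multiplications.
    ratio = (maximum_lr / minimum_lr) ** (1.0 / (num_batches - 1))
    return [minimum_lr * ratio ** i for i in range(num_batches)]


# Example: 1000 samples in batches of 100, sweeping 1e-5 up to 1e-1.
sweep = exp_lr_sweep(1000, 100, 1e-5, 1e-1)
```

Because the values are evenly spaced in log space, plotting the loss against `math.log10(lr)` gives an evenly spaced x-axis, which is why a log 10 scale is the natural choice when reading such a plot.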

NOTE 2 :

To visualize the plot, there are two ways -

Finding the optimal Momentum

Use the find_momentum_schedule.py script inside models/mobilenet/ for an example.

Some notes :

Finding the optimal Weight Decay

Use the find_weight_decay_schedule.py script inside models/mobilenet/ for an example.

Some notes :

Interpreting the plot

Learning Rate

Consider the above plot from using the LRFinder on the MiniMobileNetV2 model. In particular, there are a few regions above that we need to carefully interpret.

Note : The values are on a log 10 scale (since exp was used for lr_scale). All values discussed are read from the x-axis (learning rate) :

Momentum

With the learning rate found above, the next step is to determine the optimal momentum (find_momentum_schedule.py).

See the notes in the Finding the optimal momentum section on how to interpret the plot.

Weight Decay

Similarly, it is possible to use the above learning rate and momentum values to calculate the optimal weight decay (find_weight_decay_schedule.py).

Note : Due to large learning rates acting as a strong regularizer, other regularization techniques like weight decay and dropout should be decreased significantly to properly train the model.

It is best to first search a range of regularization strengths from 1e-3 to 1e-7, and then fine-search the region that provided the best overall plot.
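The coarse-then-fine search can be sketched as follows. This only generates candidate weight decay values; it does not train anything. The helper name `log_grid`, the one-value-per-decade coarse spacing, the half-decade window for the fine pass, and the "best" value of 1e-5 are all assumptions made for illustration.

```python
import math


def log_grid(high, low, points_per_decade=1):
    """Log-spaced values from high down to low, inclusive.

    Hypothetical helper for illustration; not part of the clr module.
    """
    decades = math.log10(high) - math.log10(low)
    n = int(round(decades * points_per_decade))
    return [high * 10 ** (-i / points_per_decade) for i in range(n + 1)]


# Coarse pass: one candidate per decade, from 1e-3 down to 1e-7.
coarse = log_grid(1e-3, 1e-7)

# Suppose the coarse pass looked best around 1e-5 (made-up outcome).
best = 1e-5

# Fine pass: half a decade on either side of the best coarse value.
fine = log_grid(best * 10 ** 0.5, best / 10 ** 0.5, points_per_decade=4)
```

Each value in `coarse`, and later in `fine`, would be tried in a separate run of the weight decay search, and the resulting plots compared.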

See the notes in the Finding the optimal weight decay section on how to interpret the plot.

Training with OneCycleLR

Once we find the maximum learning rate, we can move on to using the OneCycleLR callback with SGD to train our model.

from clr import OneCycleLR

lr_manager = OneCycleLR(num_samples, num_epoch, batch_size, max_lr,
                        end_percentage=0.1, scale_percentage=None,
                        maximum_momentum=0.95, minimum_momentum=0.85)

model.fit(X, Y, epochs=num_epoch, batch_size=batch_size, callbacks=[model_checkpoint, lr_manager],
          ...)

There are many parameters, but a few of the important ones :

Results

For the MiniMobileNetV2 model, 2 passes of the OneCycle LR with SGD (40 epochs with max lr = 0.02, then 30 epochs with max lr = 0.005) obtained a score of 90.33%. This may not seem like much, but the model has only 650k parameters. In comparison, the same model trained with Adam at an initial learning rate of 2e-3 did not reach the same score in over 100 epochs (89.14%).

Requirements