Checkpointer doesn't resume current learning rate #225

Guriido commented 6 years ago

I found this issue when resuming a stopped training run. I use the ExponentialShift extension to divide the learning rate during training, but when I resume, the learning rate is reset to its starting value instead of the (possibly multiply) shifted value.

The workaround I use is to manually set the initial learning rate to the value of the latest checkpoint...

keisukefukuda commented 6 years ago

Sorry for the late response. I will investigate the issue soon.

kuenishi commented 6 years ago

@Guriido This looks like a serializer issue in ExponentialShift rather than in ChainerMN's checkpointer, as the checkpointer just calls the trainer's serializer (and consequently those of all owned objects). Could you make sure that the issue is not reproducible without ChainerMN? If so, I'd be happy to test it once a minimal reproducible script is provided.

Guriido commented 6 years ago

Sorry for the late answer. I tested many things but couldn't reproduce the issue without ChainerMN (I cannot affirm that my tests were exhaustive, though).

With the following script (a modified version of the MNIST example), thanks to the custom trigger, the learning rate is shifted after the second epoch. If I stop the training afterwards (at the fourth epoch, for example) and resume it by running the same script (with the same parameters, of course), the learning rate is reset to the initial value (0.1) and not the expected shifted value (0.01).

#!/usr/bin/env python
from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI
import chainermn
from chainer.training import util

class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # the size of the inputs to each layer will be inferred
            l1=L.Linear(784, n_units),  # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=60,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true', help='Use GPU')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--double_buffering', action='store_true', help='improves speed')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.double_buffering:
        args.communicator = 'pure_nccl'

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    if comm.mpi_comm.rank == 0:
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device_from_id(device).use()
        model.to_gpu()

    initial_lr = 0.1

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=initial_lr, momentum=0.9), comm, double_buffering=args.double_buffering)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train = None
        test = None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize, shuffle=False)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.observe_lr(), trigger=(1, 'epoch'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time', 'lr']))

    trainer.run()

class LRShiftTrigger(object):

    """Trigger invoked on specific epoch defined by ResNet Paper author

    Args:
        key (str): Key of value.
        compare (function): Compare function which takes current best value and
            new value and returns whether new value is better than current
            best.
        trigger: Trigger that decides the comparison interval between current
            best value and new value. This must be a tuple in the form of
            ``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.

    """

    triggers = [2, 50]

    def __init__(self):
        self._interval_trigger = util.get_trigger((1, 'epoch'))

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this trainer
                is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if the corresponding extension should be invoked in
            this iteration.

        """

        if not self._interval_trigger(trainer):
            return False

        epoch = trainer.updater.epoch
        return epoch in LRShiftTrigger.triggers

if __name__ == '__main__':
    main()

My environment settings (using Docker): Ubuntu 16.04, Python 3.5.2, Open MPI 2.1.2 with InfiniBand, mpi4py 3.0.0, chainermn 1.2.0, chainer 4.0.0b3.

If you need any other information, I will be more than happy to help.
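For reference, the schedule the script above is expected to follow can be sketched in plain Python. This is only an illustration of the arithmetic, not Chainer code: the helper name `expected_lr` and its parameters are hypothetical, and the defaults mirror the values in the script (initial rate 0.1, shift factor 0.1, shifts at the end of epochs 2 and 50).

```python
def expected_lr(completed_epochs, initial_lr=0.1, rate=0.1, shift_epochs=(2, 50)):
    """Learning rate expected once `completed_epochs` epochs have finished.

    Hypothetical helper: each epoch listed in `shift_epochs` that has
    already completed multiplies the learning rate by `rate` once.
    """
    shifts = sum(1 for epoch in shift_epochs if epoch <= completed_epochs)
    return initial_lr * rate ** shifts


# Print the expected value at a few points of the run.
for epoch in (1, 2, 4, 50):
    print(epoch, expected_lr(epoch))
```

A run resumed at the fourth epoch should therefore continue at roughly 0.01, up to floating-point rounding, rather than at 0.1.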

kuenishi commented 6 years ago

@Guriido Thank you for reporting, and for your effort to cut out a reproducible script! I successfully (?) reproduced the bug and will work on it.

kuenishi commented 6 years ago

I dug a little deeper and found that the optimizer's update rule states are actually not saved. I'll keep this open while is open.

kuenishi commented 6 years ago

@Guriido I have understood what is going on in your example code.

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

In your code, the exponential shift extension is registered after the snapshot is loaded. So the reloaded trainer has the correct learning rate, but the extension added later overwrites it with the initial 0.1. I'd recommend putting the snapshot-related code right before running the trainer, so that no stateful extension is initialized after the checkpoint is reloaded, like this:

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())


    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

This code worked correctly in my environment.
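The ordering problem described above can be illustrated with a toy model in plain Python. This is a minimal sketch under stated assumptions, not Chainer's or ChainerMN's actual implementation: `ToyOptimizer`, `ToyShift`, `ToyTrainer` and `resume` are hypothetical names, and the model assumes only that a load restores state for the extensions registered at that moment and that an extension re-applies its own value when training starts.

```python
class ToyOptimizer:
    def __init__(self, lr):
        self.lr = lr


class ToyShift:
    """Hypothetical stand-in for a stateful learning-rate extension."""

    def __init__(self, rate, init):
        self.rate = rate
        self.init = init
        self.t = 0  # number of shifts applied so far

    def initialize(self, optimizer):
        # Runs when training starts: applies the value implied by self.t.
        optimizer.lr = self.init * self.rate ** self.t

    def __call__(self, optimizer):
        self.t += 1
        optimizer.lr = self.init * self.rate ** self.t


class ToyTrainer:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.extensions = {}

    def extend(self, name, extension):
        self.extensions[name] = extension

    def save(self):
        # Snapshot the learning rate and each registered extension's counter.
        return {"lr": self.optimizer.lr,
                "ext": {n: e.t for n, e in self.extensions.items()}}

    def load(self, snapshot):
        # Only extensions registered at load time get their state back.
        self.optimizer.lr = snapshot["lr"]
        for name, ext in self.extensions.items():
            if name in snapshot["ext"]:
                ext.t = snapshot["ext"][name]

    def start(self):
        for ext in self.extensions.values():
            ext.initialize(self.optimizer)


def resume(snapshot, load_before_extend):
    trainer = ToyTrainer(ToyOptimizer(lr=0.1))
    if load_before_extend:
        trainer.load(snapshot)  # the extension is not registered yet
        trainer.extend("shift", ToyShift(rate=0.1, init=0.1))
    else:
        trainer.extend("shift", ToyShift(rate=0.1, init=0.1))
        trainer.load(snapshot)  # the extension state is restored too
    trainer.start()
    return trainer.optimizer.lr


# First run: apply one shift, then take a snapshot.
first = ToyTrainer(ToyOptimizer(lr=0.1))
first.extend("shift", ToyShift(rate=0.1, init=0.1))
first.start()
first.extensions["shift"](first.optimizer)
snapshot = first.save()

lr_load_first = resume(snapshot, load_before_extend=True)
lr_load_last = resume(snapshot, load_before_extend=False)
print(lr_load_first, lr_load_last)
```

In this toy model, loading before registering the extension ends with the learning rate back at its initial value, while loading last keeps the shifted one, which matches the behaviour seen in this issue.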

Guriido commented 6 years ago

Thank you very much for your time and effort. I indeed did not expect the ExponentialShift extension to affect the trainer status...

Maybe it would be helpful to put a warning or a note about this in the multi_node_checkpointer documentation? There is obviously a remark about calling it before, but I think it could save users trouble if there were a mention like "it is recommended to call the load right before running the trainer".

Guriido commented 6 years ago

Added in #264. Sorry for the trouble, thanks!