Closed Guriido closed 6 years ago
Sorry for the late response. I will investigate the issue soon.
@Guriido This looks like a serializer issue in ExponentialShift rather than in ChainerMN's checkpointer, since the checkpointer just calls the trainer's serializer (and consequently those of all owned objects). Could you check that the issue is not reproducible without ChainerMN? If so, I'd be happy to test it once a minimal reproducible script is provided.
Sorry for the late answer. I tested many things but couldn't reproduce the issue without ChainerMN (though I cannot claim my tests were exhaustive).
With the following script (a modified version of the MNIST example), the custom trigger shifts the learning rate after the second epoch. If I stop the training afterwards (at the fourth epoch, for example) and resume it by running the same script (with the same parameters, of course), the learning rate is reset to the initial value (0.1) instead of the expected shifted value (0.01).
#!/usr/bin/env python
from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI
import chainermn
from chainer.training import util


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # Layer sizes are given explicitly.
            l1=L.Linear(784, n_units),  # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=60,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--double_buffering', action='store_true',
                        help='improves speed')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.double_buffering:
        args.communicator = 'pure_nccl'
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        if device >= 0:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    initial_lr = 0.1

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=initial_lr, momentum=0.9),
        comm, double_buffering=args.double_buffering)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train = None
        test = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm)

    train_iter = chainer.iterators.SerialIterator(
        train, args.batchsize, shuffle=False)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.observe_lr(), trigger=(1, 'epoch'))
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time', 'lr']))
        trainer.extend(extensions.ProgressBar())

    trainer.run()
class LRShiftTrigger(object):

    """Trigger invoked at the end of the specific epochs listed in ``triggers``.

    The trigger is checked once per epoch and takes no arguments.
    """

    # Epochs at whose end the learning rate is shifted.
    triggers = [2, 50]

    def __init__(self):
        self._interval_trigger = util.get_trigger((1, 'epoch'))

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. Its ``updater.epoch`` is used
                to determine if the trigger should fire.

        Returns:
            bool: ``True`` if the corresponding extension should be invoked in
            this iteration.
        """
        # Only consider firing at epoch boundaries.
        if not self._interval_trigger(trainer):
            return False
        epoch = trainer.updater.epoch
        return epoch in LRShiftTrigger.triggers
if __name__ == '__main__':
    main()
My environment settings (using Docker): Ubuntu 16.04, Python 3.5.2, Open MPI 2.1.2 with InfiniBand, mpi4py 3.0.0, chainermn 1.2.0, chainer 4.0.0b3.
If you need any other information, I will be more than happy to help.
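For reference, the schedule that the report expects can be sketched in plain Python, independent of Chainer. The helper name `expected_lr` and its parameters are hypothetical and only illustrate the arithmetic: the learning rate is multiplied by the shift rate once for every trigger epoch that has already been completed.

```python
def expected_lr(initial_lr, rate, shift_epochs, completed_epochs):
    """Return the learning rate expected after `completed_epochs` epochs.

    Hypothetical helper: one multiplication by `rate` is applied for every
    entry of `shift_epochs` that has already been reached.
    """
    n_shifts = sum(1 for e in shift_epochs if e <= completed_epochs)
    return initial_lr * rate ** n_shifts


# Values taken from the script above: lr=0.1, rate=0.1, shifts at epochs 2 and 50.
for epoch in (1, 2, 4, 50):
    print(epoch, expected_lr(0.1, 0.1, [2, 50], epoch))
```

With these values, a run resumed at the fourth epoch should continue with a learning rate of about 0.01 (up to floating-point rounding), not 0.1.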
@Guriido Thank you for reporting, and for your effort to cut out a reproducible script! I successfully (?) reproduced the bug and will work on it.
I dug a little deeper and found that the states of the optimizer's update rules are actually not saved. I'll keep this issue open while https://github.com/chainer/chainer/issues/4749 is open.
@Guriido I have understood what is going on in your example code.
    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())
In your code, the exponential shift extension is set after loading the snapshot. So the reloaded trainer has the correct learning rate, but the extension registered afterwards overwrites it with the initial 0.1. I'd recommend putting the snapshot-related code right before trainer.run(), so that no stateful extension is initialized after the checkpoint is reloaded, like this:
    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

    # (snip)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.run()
This code worked correctly in my environment.
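The ordering effect described above can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions, not Chainer's or ChainerMN's actual implementation: `ToyOptimizer`, `ToyShift`, `ToyTrainer`, and `resume` are hypothetical names, and the model simply assumes that loading restores state only for extensions that are already registered, and that a fresh extension writes its initial value when training starts.

```python
class ToyOptimizer:
    def __init__(self, lr):
        self.lr = lr


class ToyShift:
    """Hypothetical stand-in for a stateful learning-rate extension."""

    def __init__(self, init, rate):
        self.init = init
        self.rate = rate
        self.last_value = None  # set by shift() or restored from a snapshot

    def initialize(self, optimizer):
        # Assumption: with no restored state, the initial value is written.
        optimizer.lr = self.init if self.last_value is None else self.last_value

    def shift(self, optimizer):
        optimizer.lr *= self.rate
        self.last_value = optimizer.lr


class ToyTrainer:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.extensions = {}

    def extend(self, name, extension):
        self.extensions[name] = extension

    def save(self):
        return {'lr': self.optimizer.lr,
                'extensions': {n: e.last_value
                               for n, e in self.extensions.items()}}

    def load(self, snapshot):
        self.optimizer.lr = snapshot['lr']
        # Only extensions registered before the load can get their state back.
        for name, ext in self.extensions.items():
            if name in snapshot['extensions']:
                ext.last_value = snapshot['extensions'][name]

    def start(self):
        for ext in self.extensions.values():
            ext.initialize(self.optimizer)


def resume(snapshot, load_first):
    """Rebuild a trainer from `snapshot` and return the lr it starts with."""
    trainer = ToyTrainer(ToyOptimizer(lr=0.1))
    if load_first:
        trainer.load(snapshot)  # the extension is not registered yet
        trainer.extend('shift', ToyShift(init=0.1, rate=0.1))
    else:
        trainer.extend('shift', ToyShift(init=0.1, rate=0.1))
        trainer.load(snapshot)  # the extension state is restored
    trainer.start()
    return trainer.optimizer.lr


# First run: the learning rate is shifted once, then a snapshot is taken.
first = ToyTrainer(ToyOptimizer(lr=0.1))
first.extend('shift', ToyShift(init=0.1, rate=0.1))
first.start()
first.extensions['shift'].shift(first.optimizer)
snapshot = first.save()

print('load before extend:', resume(snapshot, load_first=True))
print('extend before load:', resume(snapshot, load_first=False))
```

In this toy model, loading before registering the extension starts again from the initial value, while registering first keeps the shifted value, which mirrors the recommendation to load the checkpoint right before trainer.run().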
Thank you very much for your time and effort. I indeed did not expect the ExponentialShift extension to affect the trainer status...
Maybe it would be beneficial to put a warning or a note about this in the multi_node_checkpointer documentation?
There is obviously a remark about calling it before trainer.run(), but I think it could save users trouble if there were a mention like "it is recommended to call the load right before running the trainer".
Added in #264. Sorry for the trouble, and thanks!
I found this issue when resuming a stopped training. I use the ExponentialShift extension to divide the learning rate during training, but when resuming the training, the learning rate is reinitialized to the start value (and not the value that may already have been shifted several times). The workaround I use is to manually set the initial learning rate to the value of the latest checkpoint...