apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Modify params in callback #2759

Closed: magic282 closed this issue 8 years ago

magic282 commented 8 years ago

Hi,

I am trying to write a callback that rolls back to the best model when the evaluation metric keeps going down. However, I observed that the callback does not successfully modify the params.

For example, the perplexity at epoch 32 is 3.7, and it decreases to 2.7 at epoch 42. After I load the epoch-32 params back, the perplexity is still 2.7, so I guess the rollback failed. Is this the right way to change the model's params from a callback?

Thank you.

import mxnet as mx

# data_sets, get_inference_models, test_iwslt and get_bleu are helpers defined
# elsewhere in my project.
class CheckBleu(object):
    def __init__(self, model_name, optimizer, lr_dict):
        self.best_bleu = -1.0
        self.best_epoch = -1
        self.model_name = model_name
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.cur_lr = 1.0
        self.continue_drop_epoch = 0

    def __call__(self, epoch, symbol, arg_params, aux_params):
        # epoch_end_callback(epoch, symbol, arg_params, aux_params)
        if epoch < 3:
            print('Too early to check BLEU at epoch {0}'.format(epoch))
            return
        print('Checking BLEU for epoch {0}'.format(epoch))
        gold = data_sets['dev'][0]
        test = data_sets['dev'][1]

        model_buckets = get_inference_models(arg_params)

        test_iwslt(gold, test, model_buckets)

        cur_bleu = get_bleu(gold, test)
        if cur_bleu > self.best_bleu:
            print('Current BLEU: {0} > prev best {1} in epoch {2}'.format(cur_bleu, self.best_bleu, self.best_epoch))
            self.best_bleu = cur_bleu
            self.best_epoch = epoch
            self.continue_drop_epoch = 0
        else:
            self.continue_drop_epoch += 1
            print('Current BLEU: {0} < prev best {1} in epoch {2}'.format(cur_bleu, self.best_bleu, self.best_epoch))
            if self.continue_drop_epoch >= 5:
                print('Rolling back to prev best epoch {0}...'.format(self.best_epoch))
                _, _arg_params, __ = mx.model.load_checkpoint(self.model_name, self.best_epoch + 1)
                for name, array in _arg_params.items():
                    array.copyto(arg_params[name])
                print('Halving lr... from {0} to {1}'.format(self.cur_lr, self.cur_lr * 0.5))
                self.cur_lr *= 0.5
                for k, v in self.lr_dict.items():
                    self.lr_dict[k] = self.cur_lr
                self.optimizer.set_lr_mult(self.lr_dict)
                self.continue_drop_epoch = 0

piiswrong commented 8 years ago

No. The callback is given a copy.
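
In other words, the `epoch_end_callback(epoch, symbol, arg_params, aux_params)` hook receives copies of the parameter arrays, so writes made inside the callback never reach the arrays the optimizer is actually updating. A minimal sketch of that copy semantics (this is just an illustration, not MXNet internals):

import mxnet as mx

live_params = {'w': mx.nd.ones((2, 2))}   # stands in for the arrays the optimizer updates

def rollback_callback(arg_params):
    # this writes into the copy that was passed in, not into live_params
    mx.nd.zeros((2, 2)).copyto(arg_params['w'])

rollback_callback({k: v.copy() for k, v in live_params.items()})
print(live_params['w'].asnumpy())   # still all ones: the "rollback" never took effect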

magic282 commented 8 years ago

@piiswrong So is there any other way to roll back during training? Thanks.

mbrookhart commented 7 years ago

Bump? Did anyone ever figure out how to roll back during training?
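
One possible workaround, sketched below rather than taken from the thread: drive the training epochs yourself with the Module API instead of relying on epoch_end_callback, so your loop holds a live handle to the module and can push a reloaded checkpoint back in with set_params. Names such as net, train_iter, eval_iter, num_epochs, and evaluate_bleu are placeholders, not from the thread.

import mxnet as mx

mod = mx.mod.Module(symbol=net, data_names=['data'], label_names=['label'])
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params()
mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})

best_score, best_epoch = -1.0, -1
for epoch in range(num_epochs):
    train_iter.reset()
    for batch in train_iter:
        mod.forward_backward(batch)
        mod.update()

    # checkpoint every epoch so any of them can be restored later
    mod.save_checkpoint('rollback-demo', epoch)

    score = evaluate_bleu(mod, eval_iter)        # placeholder evaluation
    if score > best_score:
        best_score, best_epoch = score, epoch
    else:
        # metric got worse: reload the best checkpoint into the live module
        _, arg_params, aux_params = mx.model.load_checkpoint('rollback-demo', best_epoch)
        mod.set_params(arg_params, aux_params)

Because the loop owns the module, set_params here really does overwrite the parameters used in the next epoch, which is exactly what the copy handed to epoch_end_callback cannot do.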