stared / livelossplot

Live training loss plot in Jupyter Notebook for Keras, PyTorch and others
https://p.migdal.pl/livelossplot
MIT License

PlotLossesKeras has an issue with 'lr' #56

Closed: mjkvaak closed this issue 4 years ago.

mjkvaak commented 5 years ago

I'm training a CNN model and want to both a) reduce the optimizer's learning rate when training hits a plateau, using ReduceLROnPlateau(), and b) visually monitor the losses with livelossplot's PlotLossesKeras(). Using both callbacks together gives me KeyError: 'lr'. When I drop either one, training works as expected, so I suspect the two are incompatible.

Versions: Keras 2.2.4, livelossplot 0.3.4

Here's some code:

from keras.callbacks import ReduceLROnPlateau
from livelossplot import PlotLossesKeras

...

learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', 
                                            patience=3, 
                                            verbose=1, 
                                            factor=0.5, 
                                            min_lr=0.00001)

history = model.fit_generator(datagen.flow(X_train,Y_train, batch_size=batch_size),
                              epochs = epochs, 
                              validation_data = (X_val,Y_val),
                              verbose = 1, 
                              steps_per_epoch=X_train.shape[0] // batch_size,
                              callbacks=[learning_rate_reduction, PlotLossesKeras()])

And here's the error I get:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-26-bd8223167e59> in <module>
      5                               verbose = 1,
      6                               steps_per_epoch=X_train.shape[0] // batch_size,
----> 7                               callbacks=[learning_rate_reduction, PlotLossesKeras()])

~/anaconda3/envs/upwork/lib/python3.7/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/anaconda3/envs/upwork/lib/python3.7/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1416             use_multiprocessing=use_multiprocessing,
   1417             shuffle=shuffle,
-> 1418             initial_epoch=initial_epoch)
   1419 
   1420     @interfaces.legacy_generator_methods_support

~/anaconda3/envs/upwork/lib/python3.7/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    249                     break
    250 
--> 251             callbacks.on_epoch_end(epoch, epoch_logs)
    252             epoch += 1
    253             if callback_model.stop_training:

~/anaconda3/envs/upwork/lib/python3.7/site-packages/keras/callbacks.py in on_epoch_end(self, epoch, logs)
     77         logs = logs or {}
     78         for callback in self.callbacks:
---> 79             callback.on_epoch_end(epoch, logs)
     80 
     81     def on_batch_begin(self, batch, logs=None):

~/anaconda3/envs/upwork/lib/python3.7/site-packages/livelossplot/generic_keras.py in on_epoch_end(self, epoch, logs)
     63 
     64     def on_epoch_end(self, epoch, logs={}):
---> 65         self.liveplot.update(logs.copy())
     66         self.liveplot.draw()

~/anaconda3/envs/upwork/lib/python3.7/site-packages/livelossplot/generic_plot.py in update(self, log)
     80         self.logs.append(log)
     81         if self.plot_extrema:
---> 82             self._update_extrema(log)
     83 
     84     def draw(self):

~/anaconda3/envs/upwork/lib/python3.7/site-packages/livelossplot/generic_plot.py in _update_extrema(self, log)
     69     def _update_extrema(self, log):
     70         for metric, value in log.items():
---> 71             extrema = self.metrics_extrema[metric]
     72             if _is_unset(extrema['min']) or value < extrema['min']:
     73                 extrema['min'] = float(value)

KeyError: 'lr'
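The traceback shows the mechanism: Keras's `CallbackList.on_epoch_end` hands the same `logs` dict to every callback in list order, and livelossplot's `_update_extrema` then looks up every key of that dict in `self.metrics_extrema`. ReduceLROnPlateau is first in the callback list and, as far as I know for Keras 2.2.x, writes the current learning rate into `logs['lr']`, so PlotLossesKeras receives a key it never registered. The plain-Python sketch below reproduces that failure mode. `FakeLRCallback`, `FakeExtremaPlot` and the pre-populated metric list are hypothetical stand-ins, not the real Keras or livelossplot classes; the assumption is that the extrema dict is filled only from the metric names known at the start of training.

```python
class CallbackList:
    """Mimics keras.callbacks.CallbackList: every callback gets the SAME logs dict."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)


class FakeLRCallback:
    """Hypothetical stand-in for ReduceLROnPlateau: records the LR in logs."""

    def __init__(self, lr):
        self.lr = lr

    def on_epoch_end(self, epoch, logs):
        logs['lr'] = self.lr  # mutates the dict the next callback will see


class FakeExtremaPlot:
    """Hypothetical simplification of livelossplot 0.3.x extrema tracking."""

    def __init__(self, metrics):
        # Assumption: slots exist only for metrics known before training starts.
        self.metrics_extrema = {m: {'min': None, 'max': None} for m in metrics}

    def on_epoch_end(self, epoch, logs):
        for metric, value in logs.copy().items():
            extrema = self.metrics_extrema[metric]  # KeyError for an unseen key
            if extrema['min'] is None or value < extrema['min']:
                extrema['min'] = float(value)
            if extrema['max'] is None or value > extrema['max']:
                extrema['max'] = float(value)


# Keras 2.2.x key names for metrics=['accuracy'] plus validation data.
metrics = ['loss', 'acc', 'val_loss', 'val_acc']
epoch_logs = {'loss': 0.9, 'acc': 0.7, 'val_loss': 1.0, 'val_acc': 0.65}

plot = FakeExtremaPlot(metrics)
callbacks = CallbackList([FakeLRCallback(0.001), plot])

caught = None
try:
    callbacks.on_epoch_end(0, dict(epoch_logs))
except KeyError as err:
    caught = err
    print('KeyError:', err)  # prints: KeyError: 'lr'
```

With only the plot callback in the list, the same logs pass through without error, which matches the report that dropping either callback makes training work.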
stared commented 5 years ago
mjkvaak commented 5 years ago

I was tracking only accuracy:

model.compile(optimizer=RMSprop(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

I haven't tried plot_extrema=False yet (assuming it isn't the default). I'll try it later and report the result. The problem is not limited to ReduceLROnPlateau(): it also happens with LearningRateScheduler().

mjkvaak commented 5 years ago

Hi again, sorry for the delayed answer. I tried plot_extrema=False in PlotLossesKeras() and can confirm that everything now works as expected.

stared commented 5 years ago

@mjkvaak thank you for testing that.

I think there is an issue with metrics containing an underscore.

@sebastienlange, could you look into it, since you introduced https://github.com/stared/livelossplot/commit/6e7892c54851241055fa0f335e159ef6c3d6fe43 ?

sebastienlange commented 5 years ago

Hi, sorry for the late reply. I can indeed reproduce it, at least partially. I'll take some time soon to try to figure it out.

stared commented 5 years ago

@mjkvaak Does it work at 0.4.1?

mjkvaak commented 5 years ago

@stared Nope, the same KeyError: 'lr' still persists at version 0.4.1 :/.

stared commented 5 years ago

@mjkvaak So, I tried to fix that with 96c4085252c60ec7a54251c7bbb02692057a0251. I don't have your full code, so the fix is untested.

If you uninstall livelossplot and reinstall it from this repo, it may help (a regular upgrade won't pick up the change, since the version number is still 0.4.1). I would be grateful for feedback on that.

It is a quick fix; in general, the extrema code needs to be rewritten. But for now, it should work.
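One way to make extrema tracking tolerant of keys that callbacks add at runtime (such as 'lr') is to create each min/max slot the first time a metric appears instead of assuming it already exists. The sketch below is a hypothetical illustration of that idea only; it is NOT the code from commit 96c4085, whose contents are not shown in this thread, and `ExtremaTracker` is an invented name.

```python
import math


class ExtremaTracker:
    """Hypothetical tracker: NOT livelossplot's actual implementation."""

    def __init__(self, metrics=()):
        self.metrics_extrema = {}
        for metric in metrics:
            self._entry(metric)

    def _entry(self, metric):
        # Create the min/max slot on first sight rather than raising KeyError.
        return self.metrics_extrema.setdefault(
            metric, {'min': math.inf, 'max': -math.inf})

    def update(self, log):
        for metric, value in log.items():
            extrema = self._entry(metric)
            extrema['min'] = min(extrema['min'], float(value))
            extrema['max'] = max(extrema['max'], float(value))


# 'lr' is not declared up front, as when ReduceLROnPlateau injects it.
tracker = ExtremaTracker(['loss', 'val_loss'])
tracker.update({'loss': 0.9, 'val_loss': 1.0, 'lr': 0.001})
tracker.update({'loss': 0.7, 'val_loss': 1.1, 'lr': 0.0005})
print(tracker.metrics_extrema['lr'])  # {'min': 0.0005, 'max': 0.001}
```

The other obvious choice is to skip keys that were not registered up front; that also avoids the crash, but it silently leaves the learning rate out of the extrema display instead of tracking it.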

stared commented 4 years ago

The new release 0.5.0 should solve the problem. Metrics grouping was rewritten thoroughly and is now much more configurable.

@mjkvaak and @sebastienlange, let me know if it works now.