NeuromorphicProcessorProject / snn_toolbox

Toolbox for converting analog neural networks to spiking neural networks (ANN to SNN) and running them in a spiking neuron simulator.
MIT License

Load keras with custom metrics #56

Closed · RKCZ closed this issue 4 years ago

RKCZ commented 4 years ago

Hello, is it possible to load a Keras model which was compiled with custom metrics? As I understand it, this is a similar issue to #50. Do I have to hardcode the custom metrics into SNN-TB? What should be changed?

rbodo commented 4 years ago

Currently, the toolbox is set up for classification, and compiles each model with two metrics: top-1 accuracy, and top-k accuracy, where k can be specified by the user and equals 1 by default.
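For context, the user-specified k is just the k argument of Keras' built-in top-k metric. A minimal sketch of binding a specific k into a named metric (the name top_5_accuracy, the value k=5, and the model variable are illustrative assumptions, not toolbox code):

    import keras

    def top_5_accuracy(y_true, y_pred):
        # Fraction of samples whose true class is among the 5 highest scores.
        return keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=5)

    model.compile('sgd', 'categorical_crossentropy',
                  metrics=['accuracy', top_5_accuracy])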

It would be a welcome extension to make this behavior more flexible (allowing custom metrics) and to support, for instance, object detection.

For now, a custom metric is possible with the following hacks:

  1. In file snntoolbox/parsing/utils.py, import your custom metric:

    from x.y import my_metric

  2. In the custom_objects dict within the function get_custom_activations_dict, add your metric to the dict:

    'my_metric': my_metric

  3. In build_parsed_model, replace the top-k metric with yours:

    self.parsed_model.compile(
        'sgd', 'categorical_crossentropy',
        ['accuracy', keras.metrics.top_k_categorical_accuracy])

becomes

    self.parsed_model.compile(
        'sgd', 'categorical_crossentropy', 
        self.input_model.metrics)

(Here I'm assuming your model has been trained with metrics=['accuracy', 'my_metric'].)

  4. Finally, in snntoolbox/simulation/utils.py, replace

    top5score_moving += sum(in_top_k(output_b_l_t[:, :, -1], truth_b, self.top_k))
    top5acc_moving = top5score_moving / num_samples_seen

by

    top5acc_moving = keras.backend.get_value(my_metric(
        keras.backend.constant(truth_d),
        keras.backend.constant(guesses_d)))

(Here I'm assuming you implemented your metric as a Keras/TF function. You can avoid the constant-tensor conversion by using a Python version of your metric here.)
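For illustration, a plain-NumPy variant of such a metric (a sketch only; the name my_metric_np and the top-k logic are assumptions, and it presumes truth_d holds one-hot labels):

    import numpy as np

    def my_metric_np(y_true, y_pred, k=5):
        # Fraction of samples whose true class is among the k highest scores.
        # Assumes y_true is one-hot; adapt if it stores integer class indices.
        true_labels = np.argmax(y_true, axis=-1)
        top_k_preds = np.argsort(y_pred, axis=-1)[:, -k:]
        hits = [label in row for label, row in zip(true_labels, top_k_preds)]
        return float(np.mean(hits))

    # Then, in snntoolbox/simulation/utils.py:
    # top5acc_moving = my_metric_np(truth_d, guesses_d)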

Now, when you run the toolbox, you will see output like this:

Evaluating parsed model on 100 samples...
Top-1 accuracy: 100.00%  (<== accuracy)
Top-5 accuracy: 100.00%  (<== my_metric)

The first value is still the top-1 accuracy as before, but the second line now reports your metric. (Adjust the print statement to use whatever label you like.)

When simulating the SNN you will get:

Current accuracy of batch:
0.00%_10.00%_10.00%_0.00%_0.00%_10.00%_20.00%_40.00%_60.00%_60.00%_70.00%_90.00%_90.00%_100.00% (<== accuracy)
Moving accuracy of SNN (top-1, top-1): 100.00% (<== accuracy), 100.00% (<== my_metric).
Moving accuracy of ANN (top-1, top-1): 100.00% (<== accuracy), 100.00% (<== my_metric).

This recipe has been used, for instance, to support a "precision" metric. Again, a more sustainable implementation of this would be very welcome.
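As an illustration of what such a metric can look like, here is a minimal binary precision function written with Keras backend ops (a sketch, not part of SNN-TB; it follows the classic Keras-1-style metric recipe):

    import keras.backend as K

    def precision(y_true, y_pred):
        # Precision = true positives / predicted positives.
        y_pred_pos = K.round(K.clip(y_pred, 0, 1))
        true_pos = K.sum(K.round(K.clip(y_true * y_pred_pos, 0, 1)))
        pred_pos = K.sum(y_pred_pos)
        return true_pos / (pred_pos + K.epsilon())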

RKCZ commented 4 years ago

Thank you for the detailed instructions, and I am sorry it took me so long to respond. I followed the directions, but I cannot figure it out. The metric I am trying to use is keras.metrics.AUC(). The issue is that I don't know what value should be added to the custom_objects dict. I compile the original model with the following command:

    model.compile(
        optimizer=keras.optimizers.Adam(), loss=keras.losses.BinaryCrossentropy(),
        metrics=['binary_accuracy', keras.metrics.AUC(name='auc')])

and I tried to add the mapping 'auc': keras.metrics.AUC(), which resulted in the exception ValueError: Unknown metric function: {'class_name': 'AUC', 'config': {'name': ... I tried to change it to 'auc': keras.metrics.AUC().update_state(), and then I tried to create a new function to wrap the metric:

    def auc(y_true, y_pred):
        auc = keras.metrics.AUC()
        auc.update_state(y_true, y_pred)
        return auc.result().numpy()

but I always got a similar error. Do you know how to use keras.metrics instances?

rbodo commented 4 years ago

I think the custom_objects mapping might be case-sensitive; try

    'AUC': keras.metrics.AUC()

Also, it shouldn't make a difference to have name='auc' in the constructor, but to be safe I'd just leave that out at first.

By the way, using 'binary_accuracy' instead of 'accuracy' will result in unexpected behavior when testing the SNN (the toolbox assumes 'accuracy').

RKCZ commented 4 years ago

Thank you for pointing out that 'accuracy' must be specified instead of 'binary_accuracy'. However, I still get the same error even when I change the key to upper case.

rbodo commented 4 years ago

I don't know what it could be; I'll try to take a look later this week.

rbodo commented 4 years ago

I can get the model to compile using

    'AUC': AUC

in the custom_objects dict (i.e. pass the class, not an instance of AUC).
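Putting the thread together, a minimal sketch of loading such a model with the class registered (the file name model.h5 is a placeholder; load_model with custom_objects is standard Keras):

    from keras.models import load_model
    from keras.metrics import AUC

    # Register the class itself, not an instance, so Keras can rebuild the
    # metric from its serialized config.
    model = load_model('model.h5', custom_objects={'AUC': AUC})

    # Per the note above, use 'accuracy' rather than 'binary_accuracy' so the
    # toolbox finds the metric it expects.
    model.compile('adam', 'binary_crossentropy', metrics=['accuracy', AUC()])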