netrack / keras-metrics

Metrics for Keras. DEPRECATED since Keras 2.3.0
MIT License

sparse_categorical metrics do not work for tensorflow 2.7 #52

Closed: henghamao closed this issue 2 years ago

henghamao commented 2 years ago

This is great work! Although tf and Keras have official recall() and precision() metrics, those metrics only work for binary classification. We have a classification problem with three categories, and we would like to compute the recall and precision metrics for each class. In our model, the last layer is a Dense layer with the 'softmax' activation function. The loss function is 'sparse_categorical_crossentropy', because we use integer class labels for y.

import tensorflow as tf
from tensorflow.keras.layers import Dense
output = Dense(3, activation='softmax')(attention_mul)
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

The prediction output is a vector of per-class probabilities, e.g. [0.3, 0.5, 0.2]. To get the class label, we need to apply np.argmax() to the prediction results, and thus the official recall and precision metrics do not work! I found that this project provides sparse_categorical metrics, so I tried the following code to get the metrics for class label 1 and passed them to model.compile():

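import keras_metrics as km
recall_1 = km.sparse_categorical_recall(label=1)
precision_1 = km.sparse_categorical_precision(label=1)
metrics = ['accuracy', recall_1, precision_1]
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=metrics)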

Here are the error messages we got:

Epoch 1/2000
WARNING:tensorflow:`add_update` `inputs` kwarg has been deprecated. You no longer need to pass a value to `inputs` as it is being automatically inferred.
W0901 17:11:00.580176 139653739816768 base_layer.py:1764] `add_update` `inputs` kwarg has been deprecated. You no longer need to pass a value to `inputs` as it is being automatically inferred.
/usr/local/python3.8/lib/python3.8/site-packages/keras_metrics/metrics.py:26: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  return object.__getattribute__(self, name)
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.recall object at 0x7f02f39d6d30>, because it is not built.
W0901 17:11:14.928323 139653739816768 save_impl.py:71] Skipping full serialization of Keras layer <keras_metrics.metrics.recall object at 0x7f02f39d6d30>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.precision object at 0x7f02f2d1dac0>, because it is not built.
W0901 17:11:14.929263 139653739816768 save_impl.py:71] Skipping full serialization of Keras layer <keras_metrics.metrics.precision object at 0x7f02f2d1dac0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.true_positive object at 0x7f02f3701700>, because it is not built.
W0901 17:11:14.929968 139653739816768 save_impl.py:71] Skipping full serialization of Keras layer <keras_metrics.metrics.true_positive object at 0x7f02f3701700>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.false_negative object at 0x7f02f36ca8e0>, because it is not built.
W0901 17:11:14.930354 139653739816768 save_impl.py:71] Skipping full serialization of Keras layer <keras_metrics.metrics.false_negative object at 0x7f02f36ca8e0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.true_positive object at 0x7f02f2dc4970>, because it is not built.
W0901 17:11:14.930832 139653739816768 save_impl.py:71] Skipping full serialization of Keras layer <keras_metrics.metrics.true_positive object at 0x7f02f2dc4970>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras_metrics.metrics.false_positive object at 0x7f02f2d24580>, because it is not built.

Traceback (most recent call last):
  File "main.py", line 537, in <module>
    tf.compat.v1.app.run()
  File "/usr/local/python3.8/lib/python3.8/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/python3.8/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/python3.8/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 476, in main
    model.train(stock_data_list, test_ratio)
  File "/root/work/branch/infinity_stock3/models/model_attention_three_category.py", line 101, in train
    self.model.fit(train_g, steps_per_epoch=math.ceil(train_len / self.batch_size),
  File "/usr/local/python3.8/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/python3.8/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py", line 53, in _python_properties_internal
    metadata.update(get_serialized(self.obj))
ValueError: dictionary update sequence element #0 has length 1; 2 is required
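The argmax step described in the issue can be checked offline on predicted probabilities. The following is a minimal NumPy sketch, not part of keras-metrics: the probability values and labels are invented for illustration, and per_class_precision_recall is a hypothetical helper. It converts softmax outputs to class labels and computes one-vs-rest precision and recall for label 1.

```python
import numpy as np

# Softmax outputs for four samples over three classes (illustrative values).
probs = np.array([
    [0.3, 0.5, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
])
# Integer class labels, as used with sparse_categorical_crossentropy.
y_true = np.array([1, 1, 0, 1])

# np.argmax turns each probability vector into a predicted class label.
y_pred = np.argmax(probs, axis=-1)


def per_class_precision_recall(y_true, y_pred, label):
    """Hypothetical helper: one-vs-rest precision and recall for one label."""
    tp = np.sum((y_pred == label) & (y_true == label))
    fp = np.sum((y_pred == label) & (y_true != label))
    fn = np.sum((y_pred != label) & (y_true == label))
    # Return 0.0 instead of dividing by zero when a denominator is empty.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


precision_1, recall_1 = per_class_precision_recall(y_true, y_pred, label=1)
```

Here y_pred is [1, 0, 1, 2], so for label 1 there is one true positive, one false positive and two false negatives, giving precision 0.5 and recall 1/3. The same function could be applied to the output of model.predict().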
henghamao commented 2 years ago

Thanks to this project, I finally found a solution using a custom metric function. See my answer on Stack Overflow: https://stackoverflow.com/questions/73564461/recall-and-precision-metrics-for-multi-class-classification-in-tensorflow-keras/73633124#73633124
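The linked answer is not reproduced here. As a rough sketch of the same general idea, assuming TensorFlow 2.x with tf.keras, a custom stateful metric can apply argmax inside update_state so that it accepts integer labels and softmax outputs directly. SparseClassRecall is a hypothetical name, and this sketch ignores sample weights.

```python
import tensorflow as tf


class SparseClassRecall(tf.keras.metrics.Metric):
    """Hypothetical per-class recall for integer labels and softmax outputs.

    Not part of keras-metrics or tf.keras; sample_weight is ignored.
    """

    def __init__(self, class_id, name=None, **kwargs):
        super().__init__(name=name or f"recall_{class_id}", **kwargs)
        self.class_id = class_id
        # Running counts accumulated across batches.
        self.tp = self.add_weight(name="tp", initializer="zeros")
        self.fn = self.add_weight(name="fn", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true: integer labels with shape (batch,) or (batch, 1).
        y_true = tf.reshape(tf.cast(y_true, tf.int64), [-1])
        # y_pred: probabilities with shape (batch, n_classes) -> labels.
        y_pred = tf.argmax(y_pred, axis=-1)
        is_class = tf.equal(y_true, self.class_id)
        predicted_class = tf.equal(y_pred, self.class_id)
        tp = tf.logical_and(is_class, predicted_class)
        fn = tf.logical_and(is_class, tf.logical_not(predicted_class))
        self.tp.assign_add(tf.reduce_sum(tf.cast(tp, self.dtype)))
        self.fn.assign_add(tf.reduce_sum(tf.cast(fn, self.dtype)))

    def result(self):
        # divide_no_nan returns 0 when no samples of this class were seen.
        return tf.math.divide_no_nan(self.tp, self.tp + self.fn)

    def reset_state(self):
        self.tp.assign(0.0)
        self.fn.assign(0.0)

    def get_config(self):
        config = super().get_config()
        config.update({"class_id": self.class_id})
        return config
```

Under these assumptions it would be passed to compile as metrics=['accuracy', SparseClassRecall(class_id=1)]. A precision counterpart would count false positives (predicted class 1 while the true label is not 1) in place of false negatives.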

henghamao commented 2 years ago

See the answer above.