netrack / keras-metrics

Metrics for Keras. DEPRECATED since Keras 2.3.0
MIT License

What does the `reset_states` do? #18

Closed by jyhong836 5 years ago

jyhong836 commented 5 years ago

It seems that the reset_states method of a metric resets its stored values. However, I am not sure when it should be used. Is it meant for resetting the states at the end of each epoch?

As I understand it, keras-metrics is designed to avoid the incorrect approximation of recall that results from computing it on each batch. Thus, a practical solution is to compute the metrics independently at the end of each epoch.
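To make the batch-averaging problem concrete, here is a minimal, dependency-free sketch. The data is made up for illustration and the `recall` helper is hypothetical, not part of keras-metrics. It shows that the mean of per-batch recall values can differ noticeably from the recall computed over the whole epoch when batches contain different numbers of positive labels.

```python
def recall(y_true, y_pred):
    """Recall = TP / (TP + FN) for binary labels; 0.0 if there are no positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp / (tp + fn) if (tp + fn) else 0.0


# Two batches of (labels, hard predictions); illustrative data only.
batches = [
    ([1, 0, 0, 0], [1, 0, 0, 0]),  # 1 positive, 1 found
    ([1, 1, 1, 1], [1, 0, 0, 0]),  # 4 positives, 1 found
]

# What a batch-wise metric averaged over the epoch would report.
per_batch = [recall(t, p) for t, p in batches]
batch_averaged = sum(per_batch) / len(per_batch)

# What the recall over all samples of the epoch actually is.
all_true = [t for labels, _ in batches for t in labels]
all_pred = [p for _, preds in batches for p in preds]
epoch_recall = recall(all_true, all_pred)

print(per_batch, batch_averaged, epoch_recall)
```

The two numbers disagree because averaging gives every batch the same weight, regardless of how many positive samples it contains.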

But in the README.md, the given example is

import keras
import keras_metrics

model = keras.models.Sequential()
model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
model.add(keras.layers.Dense(1, activation="softmax"))

model.compile(optimizer="sgd",
              loss="binary_crossentropy",
              metrics=[keras_metrics.precision(), keras_metrics.recall()])

which passes keras_metrics.recall() directly as a metric, so it is evaluated batch by batch. The problem with the demo is that the states may* not be reset, in which case the recall value of each epoch would depend on the previous epochs.

* I am not sure if the reset_states method is called at the end of each epoch.

ybubnov commented 5 years ago

Hi @jyhong836. reset_states is indeed called on each epoch, see https://github.com/keras-team/keras/blob/2.2.4/keras/engine/training_arrays.py#L145.

The only problem seems to be that the metrics should be marked as stateful, and currently they are not. Thank you for noticing that, I'll prepare the appropriate changes.
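As a rough illustration of why the per-epoch reset matters, the sketch below uses a plain Python class as a stand-in for a stateful metric. `RunningRecall` and `run_epochs` are hypothetical names invented for this example; they are not the Keras or keras-metrics API, and the loop only imitates the idea of calling reset_states once per epoch.

```python
class RunningRecall:
    """Hypothetical stand-in for a stateful metric (not the keras-metrics API)."""

    def __init__(self):
        self.true_positives = 0
        self.false_negatives = 0

    def reset_states(self):
        # Forget everything accumulated so far.
        self.true_positives = 0
        self.false_negatives = 0

    def update(self, y_true, y_pred):
        # Accumulate counts from one batch and return the running recall.
        for t, p in zip(y_true, y_pred):
            if t == 1 and p == 1:
                self.true_positives += 1
            elif t == 1 and p == 0:
                self.false_negatives += 1
        total = self.true_positives + self.false_negatives
        return self.true_positives / total if total else 0.0


def run_epochs(metric, epochs, reset):
    """Return the metric value reported at the end of each epoch."""
    results = []
    for batches in epochs:
        if reset:
            metric.reset_states()
        value = 0.0
        for y_true, y_pred in batches:
            value = metric.update(y_true, y_pred)
        results.append(value)
    return results


# Illustrative data: the model misses everything in epoch 1, nothing in epoch 2.
epoch_1 = [([1, 1], [0, 0])]
epoch_2 = [([1, 1], [1, 1])]

with_reset = run_epochs(RunningRecall(), [epoch_1, epoch_2], reset=True)
without_reset = run_epochs(RunningRecall(), [epoch_1, epoch_2], reset=False)
print(with_reset, without_reset)
```

Without the reset, the second epoch's value is pulled down by the counts carried over from the first epoch, which is the dependence on previous epochs described in the question.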

jyhong836 commented 5 years ago

Thank you for your reply. You are right. I hadn't noticed the call in training_arrays.py.

BTW, the unit test is not really convincing: the correctness of the true positive, false negative, and other values is not tested.

The official package already has a stateful metric test, test_stateful_metrics, although it only covers binary true positives. You may want to refer to it.
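One possible shape for such a check is a small reference counter whose results can be verified by hand and then compared with what each metric reports. The sketch below is dependency-free; `confusion_counts` is a hypothetical helper, and the 0.5 threshold is an assumption about how scores are turned into hard predictions.

```python
def confusion_counts(y_true, y_pred, threshold=0.5):
    """Reference TP/FP/FN/TN counts for binary labels and scores in [0, 1].

    A score counts as a positive prediction only if it is strictly above
    the threshold (an assumption made for this sketch).
    """
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    for truth, score in zip(y_true, y_pred):
        predicted = 1 if score > threshold else 0
        if predicted == 1 and truth == 1:
            counts["tp"] += 1
        elif predicted == 1 and truth == 0:
            counts["fp"] += 1
        elif predicted == 0 and truth == 1:
            counts["fn"] += 1
        else:
            counts["tn"] += 1
    return counts


# Small hand-checkable example; the data is made up for illustration.
y_true = [1, 1, 0, 0, 1, 0]
y_pred = [0.9, 0.2, 0.7, 0.1, 0.6, 0.5]

expected = confusion_counts(y_true, y_pred)
print(expected)
```

A test could then assert that each metric's value after an epoch equals the matching entry of `expected`, computed from the model's predictions on the same data.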

Your package is really useful. Thank you for your contribution.

jyhong836 commented 5 years ago

I tried to run the unit test from pull request #19, comparing the false_positive value against the function below:

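import numpy as np

def ref_false_pos(y_true, y_pred):
    # Count samples predicted as 1 (after rounding) whose true label is 0.
    return np.sum(np.logical_and(np.round(y_pred) == 1, y_true == 0))

y_pred = model.predict(x)  # model, x and y come from the unit test
expected_fp = ref_false_pos(y, y_pred)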

The values are occasionally not equal. Even when I fix the random seed with np.random.seed(2334), the mismatch still happens occasionally.

Is there any explanation for this nondeterminism?

ybubnov commented 5 years ago

@jyhong836, could you please post an example of a run with a failing test (output or input data)? Unfortunately, I can't reproduce this issue after merging pull request #19.

jyhong836 commented 5 years ago

@ybubnov I posted my test at jyhong836/keras-metrics, but I am not sure whether you can reproduce the result. I also included a temp_model.hdf5 file; please put it in your working directory.

Currently, I can reproduce the error on my MacBook (macOS 10.14, TensorFlow 1.5.0 CPU version, Keras 2.2.4 and 2.1.6). However, I cannot reproduce it on another Linux computer with TensorFlow 1.4.0 (GPU version) and Keras 2.1.6. I am not sure whether it is a version issue or a machine issue.

ybubnov commented 5 years ago

@jyhong836, I ran your tests on my Linux machine and managed to reproduce the issue with TensorFlow 1.5.0. There is no issue with TensorFlow 1.8.0, though.

ybubnov commented 5 years ago

The tests fail both when the model is loaded from temp_model.hdf5 and after fitting the model.

jyhong836 commented 5 years ago

So I guess there is some bug in TensorFlow <= 1.5.0.

Which TensorFlow build do you use, the GPU or the CPU version?

ybubnov commented 5 years ago

I'm able to reproduce the issue with TensorFlow 1.6.0 and 1.7.0 as well. I'm using the CPU version (from the pip repository, so not optimized for AVX2 and FMA instructions).

jyhong836 commented 5 years ago

I think there is no better solution than upgrading. I will close the issue.