netrack / keras-metrics

Metrics for Keras. DEPRECATED since Keras 2.3.0
MIT License

Example for how to use metric with multi-label (ValueError for more than 2 classes) #26

Open NumesSanguis opened 5 years ago

NumesSanguis commented 5 years ago

Since this pull request (https://github.com/netrack/keras-metrics/pull/23), it should be possible to use a y array with more than 2 classes, but I cannot get it to work.

I have 5 classes, so I tried to compute precision for a single label with:

import keras_metrics

precision = keras_metrics.precision(label=0)
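For context, "precision for one label" on a multi-class softmax output can be read as one-vs-rest binary precision for that class index. The pure-Python sketch below assumes each prediction is turned into a class by argmax; that is an assumption for illustration, not necessarily how keras_metrics thresholds its inputs, and `label_precision` is a hypothetical helper, not part of the library.

```python
def argmax(row):
    """Index of the largest score in a row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def label_precision(y_true, y_pred, label):
    """One-vs-rest precision for a single class index `label`.

    y_true: one-hot target rows; y_pred: per-class score rows (e.g. softmax).
    Hypothetical helper for illustration, not part of keras_metrics.
    """
    true_pos = false_pos = 0
    for t, p in zip(y_true, y_pred):
        if argmax(p) == label:      # the model predicted this label
            if t[label] == 1:
                true_pos += 1       # ... and the target agrees
            else:
                false_pos += 1      # ... but the true class is different
    predicted = true_pos + false_pos
    return true_pos / predicted if predicted else 0.0


# Three samples, five classes.
y_true = [[1, 0, 0, 0, 0],
          [0, 1, 0, 0, 0],
          [1, 0, 0, 0, 0]]
y_pred = [[0.7, 0.1, 0.1, 0.05, 0.05],   # predicts class 0, correct
          [0.6, 0.2, 0.1, 0.05, 0.05],   # predicts class 0, wrong
          [0.1, 0.1, 0.6, 0.1, 0.1]]     # predicts class 2
print(label_precision(y_true, y_pred, label=0))  # 0.5
```

Label 0 is predicted twice and is correct once, so its precision is 1/2.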

However, this results in the error:

~/anaconda3/envs/dl/lib/python3.6/site-packages/keras_metrics/metrics.py in _categorical(self, y_true, y_pred, dtype)
     46             return self._binary(y_true, y_pred, dtype, label=1)
     47         elif labels > 2:
---> 48             raise ValueError("With 2 and more output classes a "
     49                              "metric label must be specified")
     50 

ValueError: With 2 and more output classes a metric label must be specified

It seems to check only the shape of y, not whether a label was specified.

Would something like this be better?

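def _categorical(self, y_true, y_pred, dtype, label=None):
    # Number of output classes, taken from the last axis of the predictions.
    labels = y_pred.shape[-1]
    if labels == 2:
        # With exactly two classes, evaluate the positive class.
        label = 1
    elif labels > 2 and label is None:
        # Fail only when no label was given; "is None" keeps label=0 valid.
        raise ValueError("With 2 and more output classes a metric label must be specified")
    return self._binary(y_true, y_pred, dtype, label=label)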
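Whichever form the check takes, the difference between a truthiness test such as `and label` and an explicit `label is None` test matters: class index 0 is falsy in Python, so a truthiness test treats `label=0` exactly like a missing label. The standalone sketch below uses a hypothetical helper, `resolve_label`, that takes a plain class count instead of a tensor shape, to show the intended rule in isolation.

```python
def resolve_label(num_classes, label=None):
    """Pick the class index a binary metric should evaluate.

    Hypothetical helper mirroring the dispatch rule discussed above;
    it works on an int class count rather than a tensor shape.
    """
    if num_classes == 2:
        return 1  # two classes: use the positive class
    if num_classes > 2 and label is None:
        raise ValueError("With 2 and more output classes a metric label must be specified")
    return label


print(resolve_label(5, label=0))  # 0: label 0 is a valid choice, not "no label"
print(resolve_label(2))           # 1
try:
    resolve_label(5)
except ValueError as err:
    print("error:", err)
```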
ybubnov commented 5 years ago

Hi @NumesSanguis, thank you for posting the issue.

Could you please post the model configuration, or at least the last layer (an example of the output data), so I can understand why this fix is necessary? Thank you in advance.

Luux commented 5 years ago

@ybubnov I can confirm this issue. I tried with

f1 = keras_metrics.f1_score(label=1)
self.model.compile(optimizer="Adam", loss='binary_crossentropy', metrics=[f1])

and got the same error. My last layer is currently:

output_layer = keras.layers.Dense(self.n_classes, activation="softmax")(dense2)

The labels are one-hot encoded with 12 classes, so each target has the form:

np.array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
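To make the shapes concrete: a one-hot target with 12 classes has a 1 at the class index and 0 elsewhere (the example vector corresponds to class index 8), and a per-label metric effectively treats one column of those targets as a binary target. The helpers `one_hot` and `label_column` below are illustrative only and are not part of keras_metrics.

```python
def one_hot(index, num_classes):
    """Return a one-hot list of length num_classes with 1.0 at `index`."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


def label_column(rows, label):
    """Extract column `label` from one-hot rows: the binary target for that class."""
    return [row[label] for row in rows]


targets = [one_hot(i, 12) for i in (8, 1, 8)]
print(targets[0])                # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
print(label_column(targets, 8))  # [1.0, 0.0, 1.0]
print(label_column(targets, 1))  # [0.0, 1.0, 0.0]
```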