apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Better to flatten the label array in metric.F1() #16880

Open zburning opened 4 years ago

zburning commented 4 years ago

Description

Unlike the other metrics, the current metric.F1() doesn't flatten the label. Commonly the label has the shape (batch_size, 1), so it would be better to flatten it inside the F1() method. Otherwise F1() returns a wrong result without raising an error, which is hard to debug.
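
As a rough sketch of the suggestion (hypothetical helper name, not the current API): squeeze a trailing singleton dimension so a (batch_size, 1) label lines up with the (batch_size,) argmax of the predictions.

import numpy

def _flat_label(label):
    # Hypothetical helper: collapse a (batch_size, 1) label array to (batch_size,)
    # so it matches the (batch_size,) output of numpy.argmax on the predictions.
    label = numpy.asarray(label)
    if label.ndim == 2 and label.shape[1] == 1:
        label = label.reshape(-1)
    return label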

sxjscience commented 4 years ago

We will have label shape = (B, N_labels) in multi-label classification problems, e.g., the PPI dataset used with Graph Neural Networks (https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf), and also in multi-class classification, e.g., image classification.

Flattening the labels is thus not an option. You may refer to the definition in scikit-learn (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html) and the implementation discussed in https://github.com/apache/incubator-mxnet/issues/9586.
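
For reference, this is roughly how scikit-learn handles a (B, N_labels) multi-label target, which is why flattening would break it (plain scikit-learn example, not MXNet code):

import numpy as np
from sklearn.metrics import f1_score

# Multi-label targets: one 0/1 indicator column per label, shape (B, N_labels).
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0]])

# f1_score consumes the 2-D indicator matrices directly and averages per label.
print(f1_score(y_true, y_pred, average='micro'))  # pooled counts over all label columns
print(f1_score(y_true, y_pred, average='macro'))  # unweighted mean of per-label F1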

zburning commented 4 years ago

Thank you for the explanation! So the current implementation of metric.F1() is not ideal because it only supports the binary classification problem with label shape = (B,). I can refer to the #9586 implementation for my scripts then.

sxjscience commented 4 years ago

@zburning I think the guideline for the refactoring is to try to follow the conventions of scikit-learn.
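
Purely as an illustration of that guideline (a hypothetical sketch, not a proposed patch), a scikit-learn-style computation would keep per-label counts over a (B, N_labels) indicator matrix and then average them, e.g. macro averaging:

import numpy as np

def macro_f1(label, pred_label, eps=1e-12):
    # Hypothetical sketch of macro-averaged F1 over (B, N_labels) 0/1 indicator
    # matrices, mirroring sklearn.metrics.f1_score(..., average='macro').
    label = np.asarray(label)
    pred_label = np.asarray(pred_label)
    tp = ((pred_label == 1) & (label == 1)).sum(axis=0)  # per-label true positives
    fp = ((pred_label == 1) & (label == 0)).sum(axis=0)
    fn = ((pred_label == 0) & (label == 1)).sum(axis=0)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1.mean()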

sxjscience commented 4 years ago

@zburning Also, would you provide an example of the issue that led you to label this as a bug? That would help us understand the problem. At the moment it looks to me more like a feature request than a bug. Please correct me if I'm wrong.

zburning commented 4 years ago

@sxjscience Sorry, I didn't make it clear. Here is an example script:

import mxnet as mx

pred = mx.nd.array([[-2.4738965,   2.7095912],
                    [ 1.4827207,  -1.6053244],
                    [ 0.66689086, -1.0119148],
                    [ 0.54501575, -0.8739182],
                    [ 1.7229283,  -1.80466],
                    [-2.1540372,   2.3391898],
                    [-0.574123,    0.18217295],
                    [-1.5451021,   1.3035003],
                    [-2.366786,    2.5836499],
                    [-2.469643,    2.6291811]])

label = mx.nd.array([[1],
                     [0],
                     [0],
                     [1],
                     [0],
                     [1],
                     [0],
                     [1],
                     [1],
                     [1]])

print(pred.shape, label.shape)  # pred shape: (10, 2), label shape: (10, 1)

metric = mx.metric.F1()
metric.update([label], [pred])
print(metric.get())  # ('f1', 0.6)

metric.reset()
metric.update([label.reshape(-1)], [pred])  # label shape: (10,)
print(metric.get())  # ('f1', 0.8333333333333334) <- this is the correct result

The current F1() uses the _BinaryClassificationMetrics class to update the stats. In _BinaryClassificationMetrics.update_binary_stats(), it has:

pred = pred.asnumpy()
label = label.asnumpy().astype('int32')
pred_label = numpy.argmax(pred, axis=1)
check_label_shapes(label, pred)

The problem is that numpy.argmax(pred, axis=1) returns an array of shape (batch,), and the computation that follows requires the label to have the same shape, i.e. the label should also be (batch,). The subsequent check_label_shapes() call effectively does nothing here because its keyword argument "shape" defaults to False, so only the lengths of the two arrays are compared. As a result the (batch,) prediction array is broadcast against the (batch, 1) label array, and the function runs without error but returns a wrong result. It is easy to fix, but since you mentioned refactoring the metric to support multi-label classification, we may not rely on _BinaryClassificationMetrics() in the future? In any case, I think the current behavior of _BinaryClassificationMetrics() is problematic, and other metrics that build on it (e.g. MCC()) will potentially suffer from the same problem.
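
A NumPy-only illustration of the likely mechanism (a sketch, not the actual mxnet source): multiplying the (batch,) prediction mask by the (batch, 1) label mask broadcasts to a (batch, batch) matrix, so the counts are silently inflated instead of an error being raised. Using the example above:

import numpy

pred_label = numpy.array([1, 0, 0, 0, 0, 1, 1, 1, 1, 1])                 # argmax of pred, shape (10,)
label = numpy.array([[1], [0], [0], [1], [0], [1], [0], [1], [1], [1]])  # shape (10, 1)

pred_true = (pred_label == 1)   # shape (10,),  6 predicted positives
label_true = (label == 1)       # shape (10, 1), 6 actual positives

# The product broadcasts to (10, 10): every predicted positive is paired with
# every actual positive, so the "true positive" count becomes 6 * 6 = 36.
true_pos = (pred_true * label_true).sum()
print((pred_true * label_true).shape, true_pos)         # (10, 10) 36

# With a flattened label the comparison is per-sample, as intended.
true_pos_flat = (pred_true * label_true.reshape(-1)).sum()
print(true_pos_flat)                                    # 5

The false-positive and false-negative counts are inflated the same way (6 * 4 = 24 each), giving precision = recall = 36/60 = 0.6, which would explain the wrong ('f1', 0.6) reported above.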

samskalicky commented 4 years ago

@zachgk assign [@mli ]

sxjscience commented 4 years ago

@zburning Are you willing to fix this problem?

zburning commented 4 years ago

@sxjscience Yes, I would like to work on it.