apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

mx.metric F1 is using numpy logic #9586

Open szha opened 6 years ago

szha commented 6 years ago

The metric module uses NumPy logic and does not benefit from the existing MXNet operators.

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/metric.py#L535-L569

sxjscience commented 6 years ago

Bringing the discussion to the correct place. I implemented an NDArray version of the F1 score while running experiments; my nd_f1 is included below.

It may be useful if we want to accelerate the F1 score computation in the future. We can also take advantage of the fact that micro F1 is equivalent to accuracy for single-label classification to speed up the computation.
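The equivalence holds because each sample has exactly one predicted and one true class: a correct sample adds one TP, and a wrong sample adds one FP (to the predicted class) and one FN (to the true class). So TP + FP = TP + FN = N, precision = recall = TP / N, and micro F1 reduces to accuracy. A minimal NumPy-only check, assuming integer class labels; the helper name `micro_f1` is hypothetical and not part of mx.metric:

```python
import numpy as np


def micro_f1(pred, label, num_class):
    """Micro-averaged F1 from pooled one-vs-rest TP/FP/FN counts (hypothetical helper)."""
    pred_onehot = np.eye(num_class)[pred]    # (N, C) indicator of the predicted class
    label_onehot = np.eye(num_class)[label]  # (N, C) indicator of the true class
    tp = np.sum(pred_onehot * label_onehot)
    fp = np.sum(pred_onehot * (1 - label_onehot))
    fn = np.sum((1 - pred_onehot) * label_onehot)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


rng = np.random.default_rng(0)
pred = rng.integers(0, 5, size=1000)
label = rng.integers(0, 5, size=1000)
# Accuracy is a single comparison plus a mean, with no one-hot tensors needed
accuracy = np.mean(pred == label)
print(micro_f1(pred, label, 5), accuracy)
```

The shortcut only applies to single-label inputs; for multilabel-indicator data (the second benchmark below) the pooled counts no longer satisfy TP + FP = N.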

```python
import time

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from sklearn.metrics import f1_score


def nd_f1(pred, label, num_class, average="micro"):
    """Evaluate F1 using mx.nd.NDArray

    Parameters
    ----------
    pred : nd.NDArray
        Predicted class indices, shape (num, label_num) or (num,)
    label : nd.NDArray
        Ground-truth class indices, same shape as pred
    num_class : int
        Number of classes (must be > 1)
    average : str
        "micro" or "macro"

    Returns
    -------
    f1 : float
    """
    # Cast to float32; label is checked independently of pred
    if pred.dtype != np.float32:
        pred = pred.astype(np.float32)
    if label.dtype != np.float32:
        label = label.astype(np.float32)
    assert num_class > 1
    assert pred.ndim == label.ndim
    if num_class == 2 and average == "micro":
        # Binary / multilabel-indicator case: pool TP, FP, FN of the positive class
        tp = nd.sum((pred == 1) * (label == 1)).asscalar()
        fp = nd.sum((pred == 1) * (label == 0)).asscalar()
        fn = nd.sum((pred == 0) * (label == 1)).asscalar()
        precision = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        # One-hot encode so that the last axis indexes the class
        pred_onehot = nd.one_hot(indices=pred, depth=num_class)
        label_onehot = nd.one_hot(indices=label, depth=num_class)
        tp = pred_onehot * label_onehot
        fp = pred_onehot * (1 - label_onehot)
        fn = (1 - pred_onehot) * label_onehot
        if average == "micro":
            # Pool counts over all samples and classes
            tp = nd.sum(tp).asscalar()
            fp = nd.sum(fp).asscalar()
            fn = nd.sum(fn).asscalar()
            precision = float(tp) / (tp + fp)
            recall = float(tp) / (tp + fn)
            f1 = 2 * (precision * recall) / (precision + recall)
        elif average == "macro":
            # Per-class counts: reduce over every axis except the class axis
            if tp.ndim == 3:
                tp = nd.sum(tp, axis=(0, 1))
                fp = nd.sum(fp, axis=(0, 1))
                fn = nd.sum(fn, axis=(0, 1))
            else:
                tp = nd.sum(tp, axis=0)
                fp = nd.sum(fp, axis=0)
                fn = nd.sum(fn, axis=0)
            # Note: this is the F1 of the macro-averaged precision and recall,
            # not the mean of per-class F1 scores that sklearn reports
            precision = nd.mean(tp / (tp + fp)).asscalar()
            recall = nd.mean(tp / (tp + fn)).asscalar()
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            raise NotImplementedError
    return f1


# Benchmark settings:
#   1. 100,000 single-label samples over 50 classes
#   2. 10,000 samples x 121 binary labels (multilabel indicator)
for pred_npy, label_npy, num_class \
        in [(np.random.randint(0, 50, size=(100000,)),
             np.random.randint(0, 50, size=(100000,)),
             50),
            (np.random.randint(0, 2, size=(10000, 121)),
             np.random.randint(0, 2, size=(10000, 121)),
             2)]:
    for average in ['micro', 'macro']:
        # Reference: sklearn on NumPy arrays, 5 runs
        start = time.time()
        for _ in range(5):
            f1_npy = f1_score(y_true=label_npy, y_pred=pred_npy, average=average)
        end = time.time()
        print("Average=", average, "Npy Time Spent:", end - start)
        # Copy the data to the GPU and make one warm-up call outside the timed loop
        pred_nd = nd.array(pred_npy, ctx=mx.gpu(), dtype=np.float32)
        label_nd = nd.array(label_npy, ctx=mx.gpu(), dtype=np.float32)
        nd.waitall()
        f1_nd = nd_f1(pred=pred_nd,
                      label=label_nd,
                      num_class=num_class,
                      average=average)
        nd.waitall()
        # MXNet executes asynchronously, so wait for each call before stopping the clock
        start = time.time()
        for _ in range(5):
            f1_nd = nd_f1(pred=pred_nd,
                          label=label_nd,
                          num_class=num_class,
                          average=average)
            nd.waitall()
        end = time.time()
        print("Average=", average, "NDArray Time Spent:", end - start,
              'abs diff:', abs(f1_nd - f1_npy))
```

Result:

```text
Average= micro Npy Time Spent: 0.1795516014099121
Average= micro NDArray Time Spent: 0.033802032470703125 abs diff: 0.0
Average= macro Npy Time Spent: 0.17911505699157715
Average= macro NDArray Time Spent: 0.07393026351928711 abs diff: 4.64383991273e-06
Average= micro Npy Time Spent: 0.6379575729370117
Average= micro NDArray Time Spent: 0.029665708541870117 abs diff: 0.0
Average= macro Npy Time Spent: 0.6377367973327637
Average= macro NDArray Time Spent: 0.034937143325805664 abs diff: 0.000381544355229
```

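
The nonzero macro `abs diff` values are consistent with a definitional difference: nd_f1's macro branch returns the F1 of the macro-averaged precision and recall, while sklearn's `average='macro'` takes the unweighted mean of per-class F1 scores. In the second (multilabel) setting, nd_f1 also appears to average over the two label values {0, 1} rather than over the 121 label columns. The NumPy sketch below computes both single-label definitions side by side; the function names are hypothetical, and it assumes every class is both predicted and present at least once (otherwise some ratios are 0/0).

```python
import numpy as np


def per_class_counts(pred, label, num_class):
    """Per-class TP, FP and FN for single-label integer predictions."""
    # A true positive for class c is a correct prediction whose label is c
    tp = np.bincount(label[pred == label], minlength=num_class).astype(float)
    fp = np.bincount(pred, minlength=num_class) - tp   # predicted c, label differs
    fn = np.bincount(label, minlength=num_class) - tp  # label c, prediction differs
    return tp, fp, fn


def f1_of_macro_pr(pred, label, num_class):
    """F1 of macro-averaged precision and recall (what nd_f1 computes for macro)."""
    tp, fp, fn = per_class_counts(pred, label, num_class)
    precision = np.mean(tp / (tp + fp))
    recall = np.mean(tp / (tp + fn))
    return 2 * precision * recall / (precision + recall)


def mean_of_per_class_f1(pred, label, num_class):
    """Unweighted mean of per-class F1 scores (sklearn's average='macro')."""
    tp, fp, fn = per_class_counts(pred, label, num_class)
    # 2TP / (2TP + FP + FN) is the harmonic mean of per-class precision and recall
    return np.mean(2 * tp / (2 * tp + fp + fn))


pred = np.array([0, 0, 1, 1, 2, 2])
label = np.array([0, 1, 1, 2, 2, 2])
print(f1_of_macro_pr(pred, label, 3))        # 468/675, about 0.6933
print(mean_of_per_class_f1(pred, label, 3))  # 59/90, about 0.6556
```

On random, roughly balanced data such as the first benchmark the two definitions land close together, which fits the small difference reported there.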
szha commented 6 years ago

@sxjscience awesome. Would you propose a PR after #9777 is merged? If/when you do, remember to report the benchmark test results from #9705.

sxjscience commented 6 years ago

OK, I'll PR after it's merged.