brightmart / text_classification

all kinds of text classification models and more with deep learning
MIT License

🙊 fix macro f1 #115

Closed. iofu728 closed this pull request 5 years ago.

iofu728 commented 5 years ago

The previous code calculates macro F1 as the average of the per-class F1 scores.

```python
def compute_f1_macro_use_TFFPFN(label_dict):
    """
    compute f1_macro
    :param label_dict: {label:(TP,FP,FN)}
    :return: f1_macro
    """
    f1_dict= {}
    num_classes=len(label_dict)
    for label, tuplee in label_dict.items():
        TP,FP,FN=tuplee
        f1_score_onelabel=compute_f1(TP,FP,FN,'macro')
        f1_dict[label]=f1_score_onelabel
    f1_score_sum=0.0
    for label,f1_score in f1_dict.items():
        f1_score_sum=f1_score_sum+f1_score
    f1_score=f1_score_sum/float(num_classes)
    return f1_score
```
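`compute_f1` is not defined in this snippet. The following minimal sketch assumes it returns the standard per-class F1 for the given (TP, FP, FN) counts, and shows what the old function therefore computes: the unweighted mean of per-class F1. The helper names `f1_from_counts` and `mean_of_per_class_f1` and the toy counts are hypothetical. The identity F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN) avoids computing P and R separately. For reference, scikit-learn's `f1_score(..., average='macro')` also uses this mean-of-per-class-F1 definition.

```python
from typing import Dict, Hashable, Tuple


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """Per-class F1 from raw counts; equals 2PR / (P + R), and 0 when undefined."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def mean_of_per_class_f1(label_dict: Dict[Hashable, Tuple[int, int, int]]) -> float:
    """Unweighted mean of per-class F1 (the definition used by the old code)."""
    scores = [f1_from_counts(tp, fp, fn) for tp, fp, fn in label_dict.values()]
    return sum(scores) / len(scores)


# Hypothetical toy counts in the old code's format: {label: (TP, FP, FN)}
counts = {"a": (8, 2, 0), "b": (1, 0, 9)}
print(mean_of_per_class_f1(counts))  # about 0.5354 (= 53/99)
```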

But this is not correct: macro F1 should be calculated from the average precision and the average recall. References: macro-f1-score-keras; Micro Average vs Macro average Performance in a Multiclass classification setting.
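A minimal sketch of the definition proposed here, taking the same {label: (TP, FP, FN)} input as the old function (the function name and toy counts are hypothetical). With the toy counts below, class a has F1 = 8/9 and class b has F1 = 2/11, so the mean of per-class F1 is 53/99 (about 0.535), while the harmonic mean of macro P = 0.9 and macro R = 0.55 is 99/145 (about 0.683). The two definitions can therefore differ substantially when precision and recall are unbalanced across classes.

```python
from typing import Dict, Hashable, Tuple


def macro_f1_from_macro_pr(label_dict: Dict[Hashable, Tuple[int, int, int]]) -> float:
    """Harmonic mean of macro-averaged precision and macro-averaged recall.

    label_dict maps each label to its (TP, FP, FN) counts, like the old code.
    A class with undefined precision or recall contributes 0, as in fastF1.
    """
    precisions, recalls = [], []
    for tp, fp, fn in label_dict.values():
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    macro_p = sum(precisions) / len(precisions)
    macro_r = sum(recalls) / len(recalls)
    return 2 * macro_p * macro_r / (macro_p + macro_r) if macro_p + macro_r else 0.0


# Hypothetical toy counts: {label: (TP, FP, FN)}
counts = {"a": (8, 2, 0), "b": (1, 0, 9)}
print(macro_f1_from_macro_pr(counts))  # about 0.6828 (= 99/145)
```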

So I wrote a function, compiled with @jit, to calculate the F1 score:

```python
from numba import jit


@jit
def fastF1(result, predict):
    ''' f1 score: result holds the gold labels, predict the predicted labels '''
    true_total, r_total, p_total, p, r = 0, 0, 0, 0, 0
    total_list = []
    for trueValue in range(6):
        trueNum, recallNum, precisionNum = 0, 0, 0
        for index, values in enumerate(result):
            if values == trueValue:
                recallNum += 1  # TP + FN for this class
                if values == predict[index]:
                    trueNum += 1  # TP
            if predict[index] == trueValue:
                precisionNum += 1  # TP + FP
        R = trueNum / recallNum if recallNum else 0
        P = trueNum / precisionNum if precisionNum else 0
        true_total += trueNum
        r_total += recallNum
        p_total += precisionNum
        p += P
        r += R
        f1 = (2 * P * R) / (P + R) if (P + R) else 0
        # id2rela maps class ids to label names (assumed to be defined in the linked code)
        print(id2rela[trueValue], P, R, f1)
        total_list.append([P, R, f1])
    # macro-averaged precision and recall
    p /= 6
    r /= 6
    micro_r = true_total / r_total
    micro_p = true_total / p_total
    macro_f1 = (2 * p * r) / (p + r) if (p + r) else 0
    micro_f1 = (2 * micro_p * micro_r) / (micro_p + micro_r) if (micro_p + micro_r) else 0
    print('P: {:.2f}%, R: {:.2f}%, Micro_f1: {:.2f}%, Macro_f1: {:.2f}%'.format(
        p * 100, r * 100, micro_f1 * 100, macro_f1 * 100))
    return micro_f1, macro_f1
```

code link
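The jitted `fastF1` scans the whole dataset once per class and uses a fixed class count. The following minimal pure-Python sketch computes the same quantities, assuming integer class ids in `0..num_classes-1`. It gathers the per-class counts in a single pass and takes the number of classes as a parameter. `f1_scores` is a hypothetical name; it returns the four values `fastF1` prints (macro P, macro R, micro F1, macro F1) instead of printing them, and it drops numba so it runs without extra dependencies. Unlike `fastF1`, which silently skips out-of-range labels, it raises an error on them.

```python
from typing import Sequence, Tuple


def f1_scores(
    y_true: Sequence[int], y_pred: Sequence[int], num_classes: int
) -> Tuple[float, float, float, float]:
    """Return (macro_p, macro_r, micro_f1, macro_f1) for integer class ids.

    Macro F1 follows the definition argued for in this PR: the harmonic
    mean of macro-averaged precision and macro-averaged recall.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = [0] * num_classes          # correct predictions per class (TP)
    true_count = [0] * num_classes  # gold label is the class (TP + FN)
    pred_count = [0] * num_classes  # predicted as the class (TP + FP)
    # One pass over the data instead of one pass per class.
    for t, p in zip(y_true, y_pred):
        if not (0 <= t < num_classes and 0 <= p < num_classes):
            raise ValueError(f"label out of range: true={t}, pred={p}")
        true_count[t] += 1
        pred_count[p] += 1
        if t == p:
            tp[t] += 1

    def safe_div(a: float, b: float) -> float:
        return a / b if b else 0.0

    precisions = [safe_div(tp[c], pred_count[c]) for c in range(num_classes)]
    recalls = [safe_div(tp[c], true_count[c]) for c in range(num_classes)]
    macro_p = sum(precisions) / num_classes
    macro_r = sum(recalls) / num_classes
    macro_f1 = safe_div(2 * macro_p * macro_r, macro_p + macro_r)

    micro_p = safe_div(sum(tp), sum(pred_count))
    micro_r = safe_div(sum(tp), sum(true_count))
    micro_f1 = safe_div(2 * micro_p * micro_r, micro_p + micro_r)
    return macro_p, macro_r, micro_f1, macro_f1


# Hypothetical example with three classes.
print(f1_scores([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=3))
# macro P = 7/18, macro R = 0.5, micro F1 = 0.6, macro F1 = 7/16
```

For single-label data where every label is in range, micro precision and micro recall both reduce to accuracy, so micro F1 equals accuracy (3 of 5 correct, 0.6, in the example).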

brightmart commented 5 years ago

thank you

TingNLP commented 5 years ago

What does range(6) mean?

iofu728 commented 5 years ago

@TingNLP asked: "What does range(6) mean?"

Very sorry, that is a mistake I made in the pull request.

It should be the number of classes, but I hard-coded it to 6. Sorry for the confusion the code caused.

I have fixed the bug in the evaluation process; the fix is in pull request #133.