brightmart / text_classification

all kinds of text classification models and more with deep learning

🤐 fix textCNN eval process #133

Closed · iofu728 closed this 5 years ago

iofu728 commented 5 years ago

This is a bug-fix pull request. The bug comes from pull request #115, which I authored.

In that pull request, I changed the eval process to a faster F1-score implementation. In fact, though, that implementation hard-coded some constants that are not general, such as range(6) where the number of classes should be a parameter.
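To illustrate the problem, here is a minimal sketch (not the exact code from #115): the set of class ids has to be derived from a num_classes parameter rather than hard-coded, or datasets with more than six classes are silently mis-scored.

# Minimal sketch of the bug, not the exact #115 code.
def class_ids_hardcoded():
    return list(range(6))            # only classes 0..5 are ever counted

def class_ids_general(num_classes):
    return list(range(num_classes))  # class count is a parameter

assert class_ids_general(6) == class_ids_hardcoded()
assert len(class_ids_general(10)) == 10  # classes 6..9 are no longer dropped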

After fixing that, I tested the eval process code:

import numpy as np
from numba import jit  # assumed source of the @jit decorator used below

def do_eval(sess, textCNN, evalX, evalY, num_classes):
    ''' compute eval loss and F1 scores on (a slice of) the eval set '''
    evalX = evalX[0:3000]  # cap the eval set so evaluation stays fast
    evalY = evalY[0:3000]
    number_examples = len(evalX)
    eval_loss, eval_counter = 0.0, 0
    batch_size = 1
    predict = []

    for start, end in zip(range(0, number_examples, batch_size), range(batch_size, number_examples + batch_size, batch_size)):
        ''' evaluation in one batch '''
        feed_dict = {textCNN.input_x: evalX[start:end],
                     textCNN.input_y_multilabel: evalY[start:end],
                     textCNN.dropout_keep_prob: 1.0,
                     textCNN.is_training_flag: False}
        current_eval_loss, logits = sess.run(
            [textCNN.loss_val, textCNN.logits], feed_dict)
        predict.append(np.argmax(np.array(logits[0])))  # batch_size is 1, so logits[0] is this example
        eval_loss += current_eval_loss
        eval_counter += 1
    evalY = [np.argmax(ii) for ii in evalY]  # one-hot labels -> class indices

    if not FLAGS.multi_label_flag:  # FLAGS is defined in the surrounding training script
        predict = [int(ii > 0.5) for ii in predict]
    # gold labels first, predictions second, matching fastF1's parameter order
    _, _, f1_macro, f1_micro, _ = fastF1(evalY, predict, num_classes)
    f1_score = (f1_micro + f1_macro) / 2.0
    return eval_loss / float(eval_counter), f1_score, f1_micro, f1_macro
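
# For reference: micro-averaging pools true positives and class counts
# across all classes before computing P, R, and F1, while macro-averaging
# computes per-class P and R first and then averages them, which is what
# fastF1 below does.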

@jit
def fastF1(result: list, predict: list, num_classes: int):
    ''' per-class P/R/F1 plus macro- and micro-averaged F1;
        result holds gold label indices, predict holds predicted indices '''
    true_total, r_total, p_total, p, r = 0, 0, 0, 0, 0  # pooled counts and macro sums
    total_list = []  # per-class [P, R, f1]
    for trueValue in range(num_classes):
        trueNum, recallNum, precisionNum = 0, 0, 0  # TP, gold count, predicted count for this class
        for index, values in enumerate(result):
            if values == trueValue:
                recallNum += 1
                if values == predict[index]:
                    trueNum += 1
            if predict[index] == trueValue:
                precisionNum += 1
        R = trueNum / recallNum if recallNum else 0        # per-class recall
        P = trueNum / precisionNum if precisionNum else 0  # per-class precision
        true_total += trueNum
        r_total += recallNum
        p_total += precisionNum
        p += P
        r += R
        f1 = (2 * P * R) / (P + R) if (P + R) else 0
        total_list.append([P, R, f1])
    p, r = np.array([p, r]) / num_classes  # macro-averaged precision and recall
    micro_r, micro_p = true_total / np.array([r_total, p_total])  # pooled (micro) recall and precision
    # note: this macro F1 is the harmonic mean of macro-P and macro-R,
    # not the average of the per-class F1 values collected in total_list
    macro_f1 = (2 * p * r) / (p + r) if (p + r) else 0
    micro_f1 = (2 * micro_p * micro_r) / (micro_p + micro_r) if (micro_p + micro_r) else 0
    accuracy = true_total / len(result)
    print('P: {:.2f}%, R: {:.2f}%, Micro_f1: {:.2f}%, Macro_f1: {:.2f}%, Accuracy: {:.2f}%'.format(
        p * 100, r * 100, micro_f1 * 100, macro_f1 * 100, accuracy * 100))
    return p, r, macro_f1, micro_f1, total_list
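
As a quick sanity check (a hypothetical example, not from the PR), fastF1 can be called on a small 3-class toy input; for single-label data, micro-F1 equals plain accuracy. Depending on the numba version, @jit may need forceobj=True to compile this list-based code.

gold    = [0, 0, 1, 1, 2, 2]  # hypothetical gold labels over 3 classes
guesses = [0, 1, 1, 1, 2, 0]  # hypothetical predictions, 4 of 6 correct
p, r, macro_f1, micro_f1, per_class = fastF1(gold, guesses, 3)
assert abs(micro_f1 - 4 / 6) < 1e-9  # micro-F1 == accuracy for single-label data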