utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0
1.85k stars 342 forks

Accuracy by class – confusion_matrix_by_class() #240

Open matteozullo opened 4 years ago

matteozullo commented 4 years ago

Hi all,

I need to calculate accuracy by class in a multi-class classification problem, and I am using @Pawel-Kranzberg's confusion_matrix_by_class() function. How should I interpret the resulting confusion matrices? In my toy example, I am testing BERT on a dataset with n=30 and I have three categories resulting in the following matrices:

array([[[ 0, 14],
        [ 0, 16]],

       [[24,  0],
        [ 6,  0]],

       [[22,  0],
        [ 8,  0]]])

I could not find any guidance for interpretation in prior issues.

Thank you

Pawel-Kranzberg commented 4 years ago

Hi @matteozullo, the function is based on the scikit-learn implementation of multilabel confusion matrices. So, paraphrasing the relevant docs, each of your 2 x 2 matrices represents one of your three categories "as if binarized under a one-vs-rest transformation". In each matrix, the "i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class", where class 1 is your given category and class 0 is the rest (all other categories combined). Which basically means:

[[ True Negatives, False Positives], 
 [ False Negatives, True Positives]]

Yes, this layout differs from the examples at https://en.wikipedia.org/wiki/Confusion_matrix, which put the positive class first.
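To make the one-vs-rest idea concrete, here is a minimal sketch in plain Python that counts the four cells for each category in the [[TN, FP], [FN, TP]] layout. It is not the fast-bert or scikit-learn code; `one_vs_rest_matrices` is a hypothetical helper, and the label lists are assumptions chosen only to be consistent with the matrices posted above (the real labels and their order are unknown).

```python
def one_vs_rest_matrices(y_true, y_pred, labels):
    """Return one [[TN, FP], [FN, TP]] matrix per label (hypothetical helper)."""
    matrices = []
    for label in labels:
        tn = fp = fn = tp = 0
        for t, p in zip(y_true, y_pred):
            actual = (t == label)      # sample truly belongs to this category
            predicted = (p == label)   # model assigned this category
            if actual and predicted:
                tp += 1
            elif actual:
                fn += 1
            elif predicted:
                fp += 1
            else:
                tn += 1
        matrices.append([[tn, fp], [fn, tp]])
    return matrices


# Assumed toy labels: 16, 6 and 8 true samples per category,
# with the model predicting category 0 for every sample.
y_true = [0] * 16 + [1] * 6 + [2] * 8
y_pred = [0] * 30

matrices = one_vs_rest_matrices(y_true, y_pred, labels=[0, 1, 2])
for label, matrix in zip([0, 1, 2], matrices):
    print(label, matrix)
```

With these assumed labels the sketch reproduces the three matrices from the question, which suggests the model assigned every sample to a single category.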

E.g.:
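As a rough sketch of how the matrices posted in the question could be read under this layout, the plain-Python snippet below derives per-category accuracy, recall and precision. `per_class_metrics` is a hypothetical helper name, not part of fast-bert or scikit-learn, and the matrices are copied from the question.

```python
# Matrices copied from the question, one per category,
# each laid out as [[TN, FP], [FN, TP]].
posted = [
    [[0, 14], [0, 16]],
    [[24, 0], [6, 0]],
    [[22, 0], [8, 0]],
]


def per_class_metrics(matrix):
    """Derive one-vs-rest metrics from a single 2 x 2 matrix (hypothetical helper)."""
    (tn, fp), (fn, tp) = matrix
    total = tn + fp + fn + tp
    return {
        "accuracy": (tp + tn) / total,
        # Recall and precision are undefined when their denominator is zero.
        "recall": tp / (tp + fn) if (tp + fn) else None,
        "precision": tp / (tp + fp) if (tp + fp) else None,
    }


metrics = [per_class_metrics(m) for m in posted]
for index, values in enumerate(metrics):
    print(index, values)
```

Read this way, the first category has 16 true positives and 14 false positives, so all 30 samples were predicted as that category, while the other two categories were never predicted (recall 0 for both). Note that one-vs-rest accuracy counts true negatives, so it can look high (24/30 and 22/30 here) even for a category the model never predicts; recall may be closer to what "accuracy by class" is meant to capture.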

BTW, the small sample notwithstanding, you could rebalance your categories to try and obtain better results.