zaandahl / mewc-train

Docker implementation of TensorFlow and EfficientNet v2 for training wildlife camera trap classifiers
BSD 3-Clause "New" or "Revised" License

Add class-specific metric calculation at the end of training #1

Closed PetervanLunteren closed 10 months ago

PetervanLunteren commented 10 months ago

When training species classifiers, it is often important to know how well the model performs on individual species, rather than only the overall accuracy. This PR adds a new function to lib_model.py that is called from mewc_train.py; it evaluates each class and prints a classification report like the one below.

Classification report:
                       precision    recall  f1-score   support

     african wild cat       0.94      0.87      0.90      3795
               baboon       0.79      0.86      0.82      2195
                 bird       0.83      0.94      0.88      4714
         brown hyaena       0.53      0.92      0.67       804
                canid       0.71      0.66      0.68       830
              caracal       0.16      0.77      0.27        44
               cattle       1.00      0.90      0.94     29470
              cheetah       0.76      0.77      0.77       911
             elephant       0.92      0.88      0.90     12575
              gemsbok       0.91      0.91      0.91      5555
              giraffe       0.93      0.90      0.91      6589
                 hare       0.94      0.87      0.90      3422
                hyrax       0.19      1.00      0.32        21
klipspringer+steenbok       0.26      0.85      0.39       158
                 kudu       0.50      0.93      0.65       565
              leopard       0.34      0.77      0.47       161
                 lion       0.87      0.71      0.78      3472
             mongoose       0.48      0.85      0.62       238
              ostrich       0.89      0.91      0.90      2291
                other       0.48      0.87      0.61       819
            porcupine       0.60      0.95      0.73       203
           rhinoceros       0.20      0.95      0.33       330
       spotted hyaena       0.86      0.75      0.80      2953
            springbok       0.81      0.88      0.84      1780
                zebra       0.87      0.90      0.89      3713

             accuracy                           0.88     87608
            macro avg       0.67      0.86      0.72     87608
         weighted avg       0.91      0.88      0.89     87608
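The layout of the report above matches what scikit-learn's `classification_report` prints, so a report of this kind can be sketched as follows. This is a minimal illustration and not the code from this PR: the helper name `class_report`, the toy labels, and the three class names are assumptions chosen for the example, and the labels are assumed to be integer-encoded.

```python
import numpy as np
from sklearn.metrics import classification_report


def class_report(y_true, y_pred, class_names):
    """Build a per-class report as text and as a dict (hypothetical helper).

    y_true and y_pred are assumed to hold integer class indices that
    line up with the order of class_names.
    """
    labels = list(range(len(class_names)))
    kwargs = dict(labels=labels, target_names=class_names, zero_division=0)
    text = classification_report(y_true, y_pred, **kwargs)
    as_dict = classification_report(y_true, y_pred, output_dict=True, **kwargs)
    return text, as_dict


# Toy data, made up for illustration only.
class_names = ["caracal", "cattle", "zebra"]
y_true = np.array([0, 0, 1, 1, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 1, 0, 2, 2])

text, report = class_report(y_true, y_pred, class_names)
print(text)
```

In a training script, `y_pred` would typically be the argmax over the model's predicted class probabilities on a held-out set. The dict form is convenient for logging per-class values or writing them to a file.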

The new code additionally saves a plot of the normalised confusion matrix. This provides insight into misclassifications and can help ecologists decide which species to group together. The font size adjusts to the number of classes, and when there are only a few classes each cell's value is also printed.

[Image: confusion_matrix]
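A plot of this kind could be produced roughly as sketched below. This is not the PR's implementation: the function names, the font-size rule, the `annotate_max` threshold, and the output path are all assumptions for illustration, and the matrix is assumed to be normalised per true class (each row sums to 1).

```python
import os
import tempfile

import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend, so no display is needed
import matplotlib.pyplot as plt


def normalised_confusion_matrix(y_true, y_pred, n_classes):
    """Row-normalised confusion matrix: rows are true classes."""
    cm = np.zeros((n_classes, n_classes), dtype=float)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows without samples stay at zero instead of dividing by zero.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


def plot_confusion_matrix(cm, class_names, path, annotate_max=10):
    """Save a heatmap of cm; returns the font size used (hypothetical helper)."""
    n = len(class_names)
    font_size = max(4, 12 - n // 5)  # assumed rule: shrink text as classes grow
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm, cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(class_names, rotation=90, fontsize=font_size)
    ax.set_yticklabels(class_names, fontsize=font_size)
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")
    if n <= annotate_max:  # only write cell values when they remain legible
        for i in range(n):
            for j in range(n):
                ax.text(j, i, f"{cm[i, j]:.2f}", ha="center", va="center",
                        fontsize=font_size)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
    return font_size


# Toy data, made up for illustration only.
cm_names = ["caracal", "cattle", "zebra"]
cm_true = np.array([0, 0, 1, 1, 1, 1, 2, 2])
cm_pred = np.array([0, 1, 1, 1, 1, 0, 2, 2])

cm = normalised_confusion_matrix(cm_true, cm_pred, len(cm_names))
out_path = os.path.join(tempfile.gettempdir(), "confusion_matrix.png")
used_font_size = plot_confusion_matrix(cm, cm_names, out_path)
```

With row normalisation, the diagonal of the matrix equals the per-class recall, which makes the plot easy to read alongside a classification report.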

zaandahl commented 10 months ago

I'm waiting to get a quick review from @BWBrook on the pull request, but the code changes look good to me. Will merge this pull request once it gets another thumbs up. :)

zaandahl commented 10 months ago

I'm going to merge this pull request; I think Barry is out for a little bit longer. If there are problems we can always roll back.