OpenImaging / miqa-phase1

A web application for medical imaging quality assurance
MIT License

Improve pixel-based NN classifier #60

Closed dzenanz closed 3 years ago

dzenanz commented 3 years ago

- Is Admin one person or multiple people?
- overall_QA: do multi-class instead of regression?
- Correlate NN output with IQM metrics
- Correlate human ratings with IQM metrics
- Weighted random sampler for PyTorch
- Lower learning rate?
- Other hyper-parameter tuning
- Focal loss?
- Learning rate decay? Or other LR schedules? Cosine

dzenanz commented 3 years ago

Using full-sized 3D images takes up more GPU memory, so I am moving a 3-fold subset of the development data to a more powerful remote machine.

dzenanz commented 3 years ago

New results:

val_confusion_matrix:
[[ 5  2]
 [55 22]]
              precision    recall  f1-score   support

           0       0.08      0.71      0.15         7
           1       0.92      0.29      0.44        77

    accuracy                           0.32        84
   macro avg       0.50      0.50      0.29        84
weighted avg       0.85      0.32      0.41        84

train_confusion_matrix:
[[12  7]
 [87 38]]
              precision    recall  f1-score   support

           0       0.12      0.63      0.20        19
           1       0.84      0.30      0.45       125

    accuracy                           0.35       144
   macro avg       0.48      0.47      0.33       144
weighted avg       0.75      0.35      0.41       144

aashish24 commented 3 years ago

So I guess we are not doing so well on the negative cases. What do 5, 2 and 55, 22 represent? Counts of FN, TN, TP, FP?

dzenanz commented 3 years ago

I think (rows are true labels, columns are predictions, class 1 taken as positive):

TN FP
FN TP
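One way to check the layout is to recompute precision, recall and support from the first validation matrix and compare them with the printed report. The sketch below assumes rows are true labels and columns are predictions, which is the convention of scikit-learn's `confusion_matrix` (the report format looks like scikit-learn's, but the evaluation script is not shown here). The helper name `per_class_metrics` is made up for illustration.

```python
# First validation confusion matrix, copied from the results above.
cm = [[5, 2],
      [55, 22]]

def per_class_metrics(cm, k):
    """Precision, recall and support for class k.

    Assumes rows are true labels and columns are predicted labels.
    """
    n = len(cm)
    support = sum(cm[k])                          # row sum: samples whose true class is k
    predicted = sum(cm[i][k] for i in range(n))   # column sum: samples predicted as k
    recall = cm[k][k] / support
    precision = cm[k][k] / predicted
    return precision, recall, support

for k in range(2):
    p, r, s = per_class_metrics(cm, k)
    print(f"class {k}: precision={p:.2f} recall={r:.2f} support={s}")
```

This reproduces the report (0.08 / 0.71 / 7 for class 0 and 0.92 / 0.29 / 77 for class 1). Under that convention, and with class 1 taken as the positive class, the first row reads TN FP and the second row reads FN TP.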

dzenanz commented 3 years ago

Using weighted random sampler:

epoch 100 average loss: 0.0744
training confusion matrix:
[[69  0]
 [ 0 66]]

val_confusion_matrix:
[[ 1 10]
 [ 5 77]]
              precision    recall  f1-score   support
           0       0.17      0.09      0.12        11
           1       0.89      0.94      0.91        82
    accuracy                           0.84        93
   macro avg       0.53      0.51      0.51        93
weighted avg       0.80      0.84      0.82        93
current epoch: 100 current accuracy: 0.9355 current AUC: 0.5150 best AUC: 0.5155 at epoch 4
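For readers unfamiliar with the technique: a weighted random sampler draws training samples with replacement, with rarer classes given larger per-sample weights, so each epoch sees roughly balanced classes. That matches the near-even 69/66 split in the training matrix above. The sketch below is a pure-Python stand-in, not the PyTorch code used here: it computes inverse-frequency weights (the kind of per-sample weights `torch.utils.data.WeightedRandomSampler` accepts) and uses `random.choices` to imitate the draw. The class counts 19 and 125 are taken from the first training report in this thread; the weighting rule itself is an assumption.

```python
import random
from collections import Counter

def inverse_frequency_weights(labels):
    """One weight per sample, inversely proportional to its class frequency."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

# Imbalanced labels: 19 samples of class 0 and 125 of class 1.
labels = [0] * 19 + [1] * 125
weights = inverse_frequency_weights(labels)

# Draw one epoch's worth of indices with replacement, as a weighted sampler would.
rng = random.Random(0)
indices = rng.choices(range(len(labels)), weights=weights, k=len(labels))
drawn = Counter(labels[i] for i in indices)
print(drawn)  # the two classes come out roughly balanced
```

Because samples are drawn with replacement, the few class-0 images are repeated many times per epoch. That may help explain why the training matrix becomes perfect while validation performance on class 0 stays poor.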

dzenanz commented 3 years ago

The results are better if we use the NN produced by early stopping, e.g.:

Loaded NN model from file '/home/dzenan/miqa/miqa01-val2.pth'
Evaluating NN model on validation data
............................................................
.................................
val_confusion_matrix:
[[10  1]
 [72 10]]
              precision    recall  f1-score   support
           0       0.12      0.91      0.22        11
           1       0.91      0.12      0.22        82
    accuracy                           0.22        93
   macro avg       0.52      0.52      0.22        93
weighted avg       0.82      0.22      0.22        93
Evaluating NN model on training data
............................................................
............................................................
...............
train_confusion_matrix:
[[60  3]
 [66  6]]
              precision    recall  f1-score   support
           0       0.48      0.95      0.63        63
           1       0.67      0.08      0.15        72
    accuracy                           0.49       135
   macro avg       0.57      0.52      0.39       135
weighted avg       0.58      0.49      0.38       135

or

Loaded NN model from file '/home/dzenan/miqa/miqa01-val0.pth'
Evaluating NN model on validation data
............................................................
........................
val_confusion_matrix:
[[ 4  3]
 [ 3 74]]
              precision    recall  f1-score   support
           0       0.57      0.57      0.57         7
           1       0.96      0.96      0.96        77
    accuracy                           0.93        84
   macro avg       0.77      0.77      0.77        84
weighted avg       0.93      0.93      0.93        84
Evaluating NN model on training data
............................................................
............................................................
........................
train_confusion_matrix:
[[92  0]
 [ 0 52]]
              precision    recall  f1-score   support
           0       1.00      1.00      1.00        92
           1       1.00      1.00      1.00        52
    accuracy                           1.00       144
   macro avg       1.00      1.00      1.00       144
weighted avg       1.00      1.00      1.00       144
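The earlier log line "best AUC: 0.5155 at epoch 4" suggests the saved checkpoint is chosen by validation AUC. A minimal sketch of that selection logic follows, assuming a patience-based stopping rule. The training script is not shown in this thread, so the function name, the patience value and the AUC curve are all made up for illustration.

```python
def best_epoch(val_auc_per_epoch, patience=10):
    """Return (epoch, auc) of the best validation AUC seen before patience runs out.

    Epochs are 1-based.
    """
    best_auc, best_ep, since_best = float("-inf"), 0, 0
    for epoch, auc in enumerate(val_auc_per_epoch, start=1):
        if auc > best_auc:
            best_auc, best_ep, since_best = auc, epoch, 0
            # A real training loop would save the model weights here.
        else:
            since_best += 1
            if since_best >= patience:
                break  # no improvement for `patience` epochs in a row
    return best_ep, best_auc

# Made-up AUC curve, for illustration only.
history = [0.50, 0.51, 0.505, 0.52, 0.515, 0.51, 0.50]
print(best_epoch(history, patience=3))  # (4, 0.52)
```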

jeffbaumes commented 3 years ago

This one is looking better, nice!

[[ 4  3]
 [ 3 74]]

👍

What are the differences between miqa01-val2.pth and miqa01-val0.pth? Are they both early stopping or only the second?

dzenanz commented 3 years ago

Both use early stopping. val0 uses fold 0 for validation and folds 1 and 2 for training. The big difference between folds probably means that more training data would help. Here we are only using a subset (around 250 out of 10k images).
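The fold naming can be illustrated with a small sketch. The fold sizes 84, 51 and 93 are the validation supports from the reports in this thread; the image identifiers and the helper name `split_folds` are hypothetical.

```python
def split_folds(folds, val_index):
    """Use folds[val_index] for validation and concatenate the rest for training."""
    val = list(folds[val_index])
    train = [item for i, fold in enumerate(folds) if i != val_index for item in fold]
    return train, val

# Placeholder image IDs; only the fold sizes come from the reports in this thread.
sizes = [84, 51, 93]
folds = [[f"fold{i}_img{j}" for j in range(n)] for i, n in enumerate(sizes)]

for v in range(len(folds)):
    train, val = split_folds(folds, v)
    print(f"val{v}: {len(train)} training, {len(val)} validation")
```

The resulting training sizes (144, 177 and 135) match the training supports printed in the reports, and the total of 228 images is consistent with "around 250".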

dzenanz commented 3 years ago

Combining both weighted random sampler and focal loss produces slightly worse results than just weighted random sampler.

Loaded NN model from file '/home/dzenan/miqa/miqa01-val0.pth'

val_confusion_matrix:
[[ 3  4]
 [26 51]]
              precision    recall  f1-score   support

           0       0.10      0.43      0.17         7
           1       0.93      0.66      0.77        77

    accuracy                           0.64        84
   macro avg       0.52      0.55      0.47        84
weighted avg       0.86      0.64      0.72        84

train_confusion_matrix:
[[72  0]
 [ 4 68]]
              precision    recall  f1-score   support

           0       0.95      1.00      0.97        72
           1       1.00      0.94      0.97        72

    accuracy                           0.97       144
   macro avg       0.97      0.97      0.97       144
weighted avg       0.97      0.97      0.97       144

Loaded NN model from file '/home/dzenan/miqa/miqa01-val1.pth'

val_confusion_matrix:
[[ 1  7]
 [ 4 39]]
              precision    recall  f1-score   support

           0       0.20      0.12      0.15         8
           1       0.85      0.91      0.88        43

    accuracy                           0.78        51
   macro avg       0.52      0.52      0.52        51
weighted avg       0.75      0.78      0.76        51

train_confusion_matrix:
[[81  0]
 [ 0 96]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        81
           1       1.00      1.00      1.00        96

    accuracy                           1.00       177
   macro avg       1.00      1.00      1.00       177
weighted avg       1.00      1.00      1.00       177

Loaded NN model from file '/home/dzenan/miqa/miqa01-val2.pth'

val_confusion_matrix:
[[ 4  7]
 [16 66]]
              precision    recall  f1-score   support

           0       0.20      0.36      0.26        11
           1       0.90      0.80      0.85        82

    accuracy                           0.75        93
   macro avg       0.55      0.58      0.55        93
weighted avg       0.82      0.75      0.78        93

train_confusion_matrix:
[[66  0]
 [ 8 61]]
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        66
           1       1.00      0.88      0.94        69

    accuracy                           0.94       135
   macro avg       0.95      0.94      0.94       135
weighted avg       0.95      0.94      0.94       135
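For reference, focal loss scales the cross-entropy of each sample by (1 - p_t)^gamma, where p_t is the probability assigned to the true class, so confidently correct samples contribute little to the loss. The sketch below is a plain-Python version of the binary form. The values of gamma and alpha used for the runs above are not stated in this thread, so the defaults here are assumptions.

```python
import math

def binary_focal_loss(p, y, gamma=2.0, alpha=None):
    """Focal loss for one sample.

    p: predicted probability of class 1, y: true label (0 or 1).
    With gamma=0 and alpha=None this reduces to ordinary cross-entropy.
    """
    p_t = p if y == 1 else 1.0 - p            # probability assigned to the true class
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:                     # optional class-balancing factor
        loss *= alpha if y == 1 else 1.0 - alpha
    return loss

easy = binary_focal_loss(0.95, 1)                 # confidently correct sample
hard = binary_focal_loss(0.30, 1)                 # misclassified sample
ce_easy = binary_focal_loss(0.95, 1, gamma=0.0)   # plain cross-entropy for comparison
print(f"easy: focal={easy:.5f} vs cross-entropy={ce_easy:.5f}")
print(f"hard: focal={hard:.5f}")
```

With gamma = 2, the easy sample's loss is scaled by 0.05^2 = 0.0025 relative to cross-entropy, while the hard sample keeps about half of its loss (0.7^2 = 0.49).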

dzenanz commented 3 years ago

Second follow-up meeting with the AI team was last Monday.

Status/progress: Tiling the image with minimal overlap, doing average of tile outputs. Classifying into good/bad. Tested with both the small set (80+50+90) and with the large set (5 folds, ~1k images each). 1 fold used for validation, the rest for training. Results for both (below are for the big set) tell the same story: the network performs poorly on unseen bad images.
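A rough sketch of what "tiling with minimal overlap, then averaging" could look like along one axis. The exact tiling scheme used here is not described, so this is one plausible reading: use the fewest tiles that cover the axis and space them evenly. The function names and sizes are hypothetical; for a 3D image the same offsets would be computed per axis.

```python
import math

def tile_starts(length, tile):
    """Start offsets of the fewest tiles of size `tile` that cover `length`,
    spread evenly so neighbouring tiles overlap as little as possible."""
    if tile >= length:
        return [0]
    n = math.ceil(length / tile)        # fewest tiles that cover the axis
    step = (length - tile) / (n - 1)    # even spacing between first and last tile
    return [round(i * step) for i in range(n)]

def average_prediction(tile_outputs):
    """Image-level score as the plain mean of the per-tile outputs."""
    return sum(tile_outputs) / len(tile_outputs)

print(tile_starts(150, 64))   # [0, 43, 86]
print(average_prediction([0.9, 0.7, 0.8]))
```

One design consideration: a plain mean can dilute a defect that shows up in only one tile, so the way tile outputs are aggregated may matter for detecting bad images.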

Applied advice from before:
- Weighted random sampler for PyTorch
- Lower learning rate?
- Other hyper-parameter tuning
- Focal loss?
- Learning rate decay? Or other LR schedules? Cosine
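As a reminder of what the cosine schedule in that list does, here is a minimal sketch of the standard cosine annealing formula (the same curve `torch.optim.lr_scheduler.CosineAnnealingLR` follows). The epoch count and learning rates are hypothetical and are not the settings used for the runs in this thread.

```python
import math

def cosine_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Cosine annealing: lr_max at epoch 0, decaying to lr_min at epoch == total_epochs."""
    cos_term = 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))
    return lr_min + (lr_max - lr_min) * cos_term

# Hypothetical settings: 100 epochs, initial learning rate 1e-3.
for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_lr(epoch, 100, 1e-3):.2e}")
```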

Still TODO:
- Is Admin one person or multiple people?
- overall_QA: do multi-class instead of regression?
- Correlate NN output with IQM metrics
- Correlate human ratings with IQM metrics

Plan: do sanity checks from the TODO section above, then explore some anomaly detection approaches.

Suggestions:
- Reduce LR exponentially?
- Predict overall QA from other columns
- Have multi-class classifier
- Pretrained net (from MONAI male/female)
- Smaller network?
- Many NN, one each for 1 class (redundant with multi-class?)
- Utilize synthetic bad images (from Qingyu @ Stanford)
- Ask Rodney LaLonde to help for a few days

dzenanz commented 3 years ago

Closing in favor of separate new issues, e.g. https://github.com/OpenImaging/miqa/issues/30.