binli123 / dsmil-wsi

DSMIL: Dual-stream multiple instance learning networks for tumor detection in Whole Slide Image
MIT License
358 stars 88 forks

About test/validation set for the TCGA dataset #55

Closed · wwyi1828 closed this issue 6 months ago

wwyi1828 commented 2 years ago

Thanks for the great work!

You mentioned that you randomly split 840 training slides and 210 testing slides in the paper. In the train_tcga.py, I noticed that all the slides are used for training and testing. The best model for the test set will be saved. The function of the test set here is more like a validation set. May I ask how you select the best model and what data you test your model on? Thanks!

binli123 commented 2 years ago

You can perform cross-validation on the training set, and find the best threshold by considering all folds.
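One common way to pick a per-fold threshold from validation predictions is to maximize Youden's J statistic (TPR − FPR) on the ROC curve. A minimal sketch with hypothetical validation labels and scores (not taken from the repository):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical validation-fold outputs: ground-truth labels and model scores.
y_valid = np.array([0, 0, 1, 1, 1, 0, 1, 0])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9, 0.3])

# Pick the threshold maximizing Youden's J (TPR - FPR) on this fold.
fpr, tpr, thresholds = roc_curve(y_valid, scores)
best_threshold = thresholds[np.argmax(tpr - fpr)]
print(best_threshold)
```

Repeating this per fold and aggregating (e.g. averaging) the fold thresholds gives one threshold to apply at test time.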

Vison307 commented 1 year ago

@binli123 Thank you for your excellent work!

I ran into the same problem. Did the test results of DSMIL on the TCGA dataset in Table 2 of the paper come from the best-scoring model's performance (L203 of train_tcga.py) on the test_path dataset (L182 of train_tcga.py)? Or did they come from cross-validation, averaging over all validation splits (or reducing them with some other operation)?

binli123 commented 6 months ago

@Vison307 I incorporated training and testing into the same pipeline in the latest commit. I also added orthogonal weight initialization, which helps make training more stable. You can set --eval_scheme=5-fold-cv-standalone-test, which performs a train/valid/test split like this:

A standalone test set consisting of 20% of the samples is reserved; the remaining 80% are used for 5-fold cross-validation. For each fold, the best model and its corresponding threshold are saved. After the cross-validation, the 5 best models and their optimal thresholds are used to run inference on the reserved test set. The final prediction for a test sample is the majority vote of the 5 models. For binary classification, accuracy and balanced accuracy are computed; for multi-label classification, Hamming loss (smaller is better) and subset accuracy are computed.
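The scheme above can be sketched as follows. This is a hedged illustration with synthetic features and simple sklearn classifiers standing in for DSMIL training, not the repository's actual pipeline:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_curve

# Synthetic stand-in for slide-level features and binary labels.
X, y = make_classification(n_samples=200, n_features=16, random_state=0)

# Reserve a standalone 20% test set; the rest feeds the 5-fold CV.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

fold_models, fold_thresholds = [], []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, valid_idx in skf.split(X_trainval, y_trainval):
    model = LogisticRegression(max_iter=1000).fit(
        X_trainval[train_idx], y_trainval[train_idx])
    # Pick the fold's threshold by maximizing Youden's J on the validation split.
    probs = model.predict_proba(X_trainval[valid_idx])[:, 1]
    fpr, tpr, thr = roc_curve(y_trainval[valid_idx], probs)
    fold_models.append(model)
    fold_thresholds.append(thr[np.argmax(tpr - fpr)])

# Majority vote of the 5 thresholded models on the reserved test set.
votes = np.stack([(m.predict_proba(X_test)[:, 1] >= t).astype(int)
                  for m, t in zip(fold_models, fold_thresholds)])
pred = (votes.sum(axis=0) >= 3).astype(int)

print("accuracy:", accuracy_score(y_test, pred))
print("balanced accuracy:", balanced_accuracy_score(y_test, pred))
```

In the actual pipeline each fold would train a full DSMIL model; the structure (held-out test set, per-fold best model plus threshold, majority vote) is the same.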

You can also simply run a 5-fold CV with --eval_scheme=5-fold-cv.

There were some issues with the testing script when loading pretrained weights (i.e., sometimes the weights are not fully loaded or some weights are missing; setting strict=False can reveal these problems). The purpose of the testing script is to generate the heatmap; you should now read the performance directly from the training script. I will fix the issues in a couple of days.
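As an illustration of how strict=False surfaces these mismatches, here is a minimal sketch with a toy model standing in for DSMIL (the layers and checkpoint are hypothetical):

```python
import torch.nn as nn

# Toy model and a checkpoint that is missing the second linear layer.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
checkpoint = nn.Sequential(nn.Linear(4, 8)).state_dict()

# strict=False loads whatever matches and reports what did not,
# instead of raising a RuntimeError as strict=True would.
result = model.load_state_dict(checkpoint, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

Inspecting missing_keys and unexpected_keys this way shows exactly which weights were silently left at their initial values.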