binli123 / dsmil-wsi

DSMIL: Dual-stream multiple instance learning networks for tumor detection in Whole Slide Image
MIT License
378 stars 88 forks

What is the 'init.pth' #26

Closed: irfanICMLL closed this issue 8 months ago

irfanICMLL commented 3 years ago

Thanks for sharing this amazing work. I wonder what the 'init.pth' is for the DSMIL model. Is it trained on some specific dataset?

Do I need to load it if I am doing other tasks (not the medical ones)?

Best wishes, Yifan

binli123 commented 3 years ago

> Thanks for sharing this amazing work. I wonder what the 'init.pth' is for the DSMIL model. Is it trained on some specific dataset?
>
> Do I need to load it if I am doing other tasks (not the medical ones)?
>
> Best wishes, Yifan

Hi Yifan, init.pth is a set of good initialization weights that is especially useful for highly unbalanced bags (i.e., small portions of positive instances in positive bags). Without it, training can sometimes take a very long time to even start converging on the Camelyon16 dataset. It was trained for a few iterations on the Camelyon16 dataset following the original training/testing split. I would suggest loading this set of weights for other tasks if possible, to facilitate convergence.
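For reference, loading these weights amounts to a standard state-dict load. Below is a minimal, hypothetical sketch; the placeholder model stands in for the actual DSMIL aggregator from the repository's dsmil module, and the checkpoint keys must match the model definition:

```python
import torch
import torch.nn as nn

# Placeholder standing in for the DSMIL aggregator; the real network is
# built from the repository's dsmil module and must match the checkpoint keys.
milnet = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 2))

# Load the provided initialization weights before training starts.
state_dict = torch.load('init.pth', map_location='cpu')
milnet.load_state_dict(state_dict)
```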

shimaxy commented 2 years ago

Hi,

Thanks for sharing your code. I wanted to concatenate two sets of features extracted from my dataset (both at the same magnification) using compute_feats.py, then use the new set of features to train an aggregator (train_tcga.py). I will have a new set of 1024-dimensional features this way, so I cannot use init.pth to initialize DSMIL. Do you have any suggestions on how to initialize DSMIL?

Also, do we need to change the criterion function for multi-class datasets from BCEWithLogitsLoss to something else?

Thanks!

binli123 commented 2 years ago

> Hi,
>
> Thanks for sharing your code. I wanted to concatenate two sets of features extracted from my dataset (both at the same magnification) using compute_feats.py, then use the new set of features to train an aggregator (train_tcga.py). I will have a new set of 1024-dimensional features this way, so I cannot use init.pth to initialize DSMIL. Do you have any suggestions on how to initialize DSMIL?
>
> Also, do we need to change the criterion function for multi-class datasets from BCEWithLogitsLoss to something else?
>
> Thanks!

The model should be able to train even without the initialization weights. As a sanity check, you could try loading ResNet50 pretrained weights (or the corresponding ResNet that outputs a 1024-vector) and see if the model converges.

BCE loss handles multi-class problems. I think the difference from CrossEntropy loss is that it treats each class individually as a binary classification problem and does not assume the class probabilities sum to one as if they were drawn from a joint distribution. But to use BCE loss, the class labels need to be a multi-digit binary vector, while for CrossEntropy loss the label is a single value, the index of the class in the one-hot coding. BCE loss also handles multi-label problems: for example, some slides can be positive for multiple classes, in which case the labels will be [0, 1, 1, 0, ...], with 1 flagging each class the slide belongs to.
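To illustrate the two label formats described above, here is a small sketch with made-up tensors (4 bags, 3 classes):

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # 4 bags, 3 classes (random stand-in predictions)

# CrossEntropyLoss: one class index per sample; the predicted class
# probabilities are assumed to sum to one.
ce_targets = torch.tensor([0, 2, 1, 2])
loss_ce = nn.CrossEntropyLoss()(logits, ce_targets)

# BCEWithLogitsLoss: one binary flag per class, so a sample may be
# positive for several classes at once (multi-label).
bce_targets = torch.tensor([[1., 0., 0.],
                            [0., 1., 1.],   # positive for two classes
                            [0., 1., 0.],
                            [0., 0., 1.]])
loss_bce = nn.BCEWithLogitsLoss()(logits, bce_targets)
```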

shimaxy commented 2 years ago

Thank you so much for your thorough explanation!

thomascong121 commented 1 year ago

> Without it, training can sometimes take a very long time to even start converging on the Camelyon16 dataset. It was trained for a few iterations on the Camelyon16 dataset

Hi, thanks for the great work. It seems this 'init.pth' is important for the performance. However, what if I use another backbone, for example ResNet50 instead of ResNet18? How should I get 'init.pth' then? Can you provide a training script for obtaining it?

binli123 commented 1 year ago

> Without it, training can sometimes take a very long time to even start converging on the Camelyon16 dataset. It was trained for a few iterations on the Camelyon16 dataset
>
> Hi, thanks for the great work. It seems this 'init.pth' is important for the performance. However, what if I use another backbone, for example ResNet50 instead of ResNet18? How should I get 'init.pth' then? Can you provide a training script for obtaining it?

It is just a set of good initialization weights. You can try out a few of the initialization methods in PyTorch: https://pytorch.org/docs/stable/nn.init.html
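For example, here is a sketch of applying one such scheme (Xavier initialization, chosen for illustration; the model below is a placeholder):

```python
import torch.nn as nn

def init_weights(m):
    # Xavier/Glorot initialization for linear layers; torch.nn.init offers
    # several alternatives (kaiming_normal_, orthogonal_, etc.).
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))
model.apply(init_weights)  # applies init_weights recursively to every submodule
```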

bryanwong17 commented 10 months ago

Hi @binli123, could you please let me know which parts of the model I should initialize using the various initialization methods in PyTorch? Additionally, do you have an example? I'm currently implementing DSMIL without 'init.pth' because my extracted patch features have a different dimension (not 512). However, I've noticed that the performance differs significantly across seeds.

binli123 commented 8 months ago

> Hi @binli123, could you please let me know which parts of the model I should initialize using the various initialization methods in PyTorch? Additionally, do you have an example? I'm currently implementing DSMIL without 'init.pth' because my extracted patch features have a different dimension (not 512). However, I've noticed that the performance differs significantly across seeds.

I incorporated training and testing into the same pipeline in the latest commit. This change allows you to read the evaluation results on a reserved test set. I also incorporated a simple weight-initialization method which helps stabilize the training. You can set --eval_scheme=5-fold-cv-standalone-test, which will perform a train/valid/test split like this:

A standalone test set consisting of 20% of the samples is reserved; the remaining 80% of the samples are used to construct a 5-fold cross-validation. For each fold, the best model and its corresponding threshold are saved. After the 5-fold cross-validation, the 5 best models, along with their corresponding optimal thresholds, are used to perform inference on the reserved test set. The final prediction for a test sample is the majority vote of the 5 models. For binary classification, accuracy and balanced accuracy scores are computed. For multi-label classification, Hamming loss (smaller is better) and subset accuracy are computed.

You can also simply run a 5-fold CV with --eval_scheme=5-fold-cv.
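A schematic sketch of the standalone-test scheme follows (the data, model, and threshold handling are placeholders; the actual implementation lives in the repository's training script):

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

X = np.random.randn(100, 512)          # placeholder bag features
y = np.random.randint(0, 2, size=100)  # placeholder binary labels

# Reserve a standalone test set of 20% of the samples.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# 5-fold cross-validation on the remaining 80%; each fold contributes
# its best model's predictions on the reserved test set.
fold_predictions = []
kf = KFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, valid_idx in kf.split(X_trainval):
    # ... train on train_idx, select the best model/threshold on valid_idx ...
    best_model_predict = lambda feats: np.random.randint(0, 2, size=len(feats))  # stand-in
    fold_predictions.append(best_model_predict(X_test))

# The final prediction is the majority vote of the 5 best models.
votes = np.stack(fold_predictions)                 # shape (5, n_test)
final_pred = (votes.sum(axis=0) >= 3).astype(int)  # at least 3 of 5 votes
```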

There were some issues with the testing script when loading pretrained weights (i.e., sometimes the weights are not fully loaded or there are missing weights; setting strict=False can reveal the problems). The purpose of the testing script is to generate the heatmap; you should now read the performance directly from the training script. I will fix the issues in a couple of days.
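For anyone debugging this in the meantime, a minimal sketch of the strict=False check (model and checkpoint path here are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 2)  # placeholder for the network being loaded
result = model.load_state_dict(
    torch.load('aggregator.pth', map_location='cpu'), strict=False)

# Non-empty lists reveal weights that were not fully loaded.
print('missing keys:', result.missing_keys)        # in model, absent from file
print('unexpected keys:', result.unexpected_keys)  # in file, absent from model
```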

binli123 commented 8 months ago

The new initialization method seems more stable, and there is no need for 'init.pth' anymore.