Closed irfanICMLL closed 8 months ago
Thanks for sharing this amazing work. I wonder what is the 'init.pth' for the dsmil model. Is it trained on some specific dataset?
Do I need to load it if I am doing other tasks (not the medical ones)?
Best wishes, Yifan
Hi Yifan,
init.pth is a set of good initialization weights that is especially useful for highly unbalanced bags (i.e., a small portion of positive instances in positive bags). Without it, training can sometimes take a very long time to even start converging on the Camelyon16 dataset. It was trained for a few iterations on the Camelyon16 dataset following the original training/testing split. I would suggest loading this set of weights for other tasks if possible to facilitate convergence.
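A minimal sketch of how such a checkpoint can be loaded into the aggregator. The layer sizes and the `milnet` stand-in below are hypothetical (the real model comes from the repo's training script), and the checkpoint here is simulated so the snippet is self-contained; `strict=False` lets a partially matching state dict load while reporting what did not match.

```python
import torch
import torch.nn as nn

# Stand-in for the DSMIL aggregator (hypothetical layer sizes).
milnet = nn.Sequential(nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 2))

# In practice this would be: state_dict = torch.load('init.pth', map_location='cpu')
# Here we simulate a checkpoint that only covers the first linear layer.
state_dict = {'0.weight': torch.zeros(128, 512), '0.bias': torch.zeros(128)}

# strict=False loads the matching keys and reports the rest instead of raising.
result = milnet.load_state_dict(state_dict, strict=False)
print(sorted(result.missing_keys))  # -> ['2.bias', '2.weight']
```

If the printed lists are non-empty, the checkpoint and the model architecture disagree, which is worth checking before blaming slow convergence on anything else.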
Hi,
Thanks for sharing your code. I wanted to concatenate two sets of features that I extracted from my dataset (both at the same magnification) using compute_feats.py, then use the new sets of features to train an aggregator (train_tcga.py). I will have a new set of 1024 features this way, so I cannot use init.pth to initialize dsmil. Do you have any suggestions on how to initialize dsmil?
Also, do we need to change the criterion function for multi-class datasets from BCEWithLogitsLoss to something else?
Thanks!
The model should be able to train even without the initialization weights. As a sanity check, you could try loading ResNet50 pretrained weights (or the corresponding ResNet that outputs a 1024-dimensional vector) and see if the model converges. BCE loss handles multi-class problems; I think the difference from CrossEntropy loss is that it treats each class individually as a binary classification problem and does not assume the class probabilities sum to one as if they were drawn from a joint distribution. But to use BCE loss the class labels need to be a multi-digit binary vector, while for CrossEntropy loss the label is a single value, the index of the class in the one-hot coding. BCE loss also handles multi-label problems: for example, some slides can be positive for multiple classes, in which case the labels will be [0, 1, 1, 0, ...], with 1 flagging each class the slide belongs to.
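The two label formats can be contrasted in a short sketch. The logits and targets below are made-up numbers for one bag with three classes; the point is only the target shape each loss expects.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])  # one bag, three classes (made-up numbers)

# BCEWithLogitsLoss: target is a multi-hot float vector; each class is an
# independent binary problem, so a slide can be positive for several classes.
bce = nn.BCEWithLogitsLoss()
multi_hot = torch.tensor([[1.0, 0.0, 1.0]])
loss_bce = bce(logits, multi_hot)

# CrossEntropyLoss: target is a single class index; the softmax forces the
# class probabilities to sum to one (mutually exclusive classes).
ce = nn.CrossEntropyLoss()
class_index = torch.tensor([0])
loss_ce = ce(logits, class_index)

print(loss_bce.item() > 0, loss_ce.item() > 0)  # -> True True
```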
Thank you so much for your thorough explanation!
without it, the training can sometimes take a very long time to even start to converge on the Camelyon16 dataset. It is trained with a few iterations on the Camelyon16 dataset
Hi, thanks for the great work, it seems this 'init.pth' is important for the performance, however, what if I use another backbone, for example using ResNet50 instead of ResNet18, how should I get this 'init.pth' then? Can you provide a training script for obtaining this 'init.pth'?
It is just a set of good initialization weights. You can try out a few initialization methods in PyTorch https://pytorch.org/docs/stable/nn.init.html
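One way to apply these, sketched under the assumption that the aggregator's trainable layers are `nn.Linear` modules (the 1024-dimensional input matches the question above; the exact layer sizes are hypothetical):

```python
import torch.nn as nn

def init_weights(m):
    # Xavier/Glorot init for linear layers; one of several options in torch.nn.init.
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Stand-in for a DSMIL aggregator taking 1024-dimensional features.
model = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))
model.apply(init_weights)  # applies init_weights recursively to every submodule

print(float(model[0].bias.abs().sum()))  # -> 0.0 after zero-init
```

Trying a couple of schemes (e.g. `xavier_normal_` vs. `kaiming_normal_`) across seeds is a cheap way to see whether the seed sensitivity mentioned above comes from initialization.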
Hi @binli123, could you please let me know which things I should initialize using the various initialization methods in PyTorch? Additionally, do you have an example? I'm currently working on implementing DSMIL without 'init.pth' because of the differing dimension of the extracted patch features (not 512). However, I've noticed that the performance differs significantly across different seeds.
I incorporated the training/testing into the same pipeline in the latest commit. This change allows you to read the evaluation results on a reserved test set. I also incorporated a simple weights initialization method which helps stabilize the training. You can set --eval_scheme=5-fold-cv-standalone-test which will perform a train/valid/test like this:
A standalone test set consisting of 20% samples is reserved, remaining 80% of samples are used to construct a 5-fold cross-validation. For each fold, the best model and corresponding threshold are saved. After the 5-fold cross-validation, 5 best models along with the corresponding optimal thresholds are obtained which are used to perform inference on the reserved test set. A final prediction for a test sample is the majority vote of the 5 models. For a binary classification, accuracy and balanced accuracy scores are computed. For a multi-label classification, hamming loss (smaller the better) and subset accuracy are computed.
You can also simply run a 5-fold cross-validation with --eval_scheme=5-fold-cv
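The splitting and majority-vote logic described above can be sketched as follows. The sample IDs, fold assignment, and per-model votes are made up for illustration; the actual implementation lives in the repo's training script.

```python
import random

# 100 hypothetical slide IDs, shuffled with a fixed seed for reproducibility.
samples = list(range(100))
random.Random(0).shuffle(samples)

test_set = samples[:20]    # 20% reserved as a standalone test set
train_val = samples[20:]   # remaining 80% used for 5-fold cross-validation

# Assign the train/valid pool to 5 folds round-robin.
folds = [train_val[i::5] for i in range(5)]
for k in range(5):
    valid = folds[k]
    train = [s for s in train_val if s not in valid]
    # ... train a model on `train`, keep the best model and threshold on `valid` ...

# Final prediction for one test sample: majority vote over the 5 fold models.
votes = [1, 0, 1, 1, 0]    # hypothetical binary predictions from the 5 models
prediction = int(sum(votes) > len(votes) / 2)
print(prediction)  # -> 1
```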
There were some issues with the testing script when loading pretrained weights (i.e., sometimes the weights are not fully loaded or there are missing weights; setting strict=False can reveal the problems). The purpose of the testing script is to generate the heatmap; you should now read the performance directly from the training script. I will fix the issues in a couple of days.
The new initialization method seems more stable, and there is no need for "init.pth".