liupei101 / PseMix

[IEEE TMI 2024] Pseudo-Bag Mixup Augmentation for Multiple Instance Learning-Based Whole Slide Image Classification
generalization mixup multiple-instance-learning pseudo-bags robustness whole-slide-images

PseMix: Pseudo-Bag Mixup Augmentation for MIL-Based Whole Slide Image Classification

IEEE Transactions on Medical Imaging, 2024

[Journal Link] | [arXiv] | [PseMix Walkthrough] | [WSI Preprocessing] | [Related Resources] | [Citation]

Abstract: Multiple instance learning (MIL) has become one of the most important frameworks for gigapixel Whole Slide Images (WSIs). In current practice, most MIL networks face two unavoidable problems in training: i) insufficient WSI data and ii) the sample memorization inclination inherent in neural networks. To address these problems, this paper proposes a new Pseudo-bag Mixup (PseMix) data augmentation scheme, inspired by the basic idea of Mixup. Working with pseudo-bags, this scheme fulfills the critical size alignment and semantic alignment in Mixup. Moreover, it is efficient and plug-and-play, neither involving time-consuming operations nor relying on model predictions. Experimental results show that PseMix could often improve the performance of state-of-the-art MIL networks. Most importantly, it could also boost the generalization performance of MIL models in special test scenarios, and promote their robustness to patch occlusion and label noise.


📚 Recent updates:

Why PseMix?

Adopting PseMix in training MIL networks could

(1) improve performance with minimal costs:

Network | AUC (BRCA) | AUC (NSCLC) | AUC (RCC) | Average AUC
ABMIL | 87.05 | 92.23 | 97.36 | 92.21
ABMIL w/ PseMix | 89.49 | 93.01 | 98.02 | 93.51
DSMIL | 87.73 | 92.99 | 97.65 | 92.79
DSMIL w/ PseMix | 89.65 | 93.92 | 97.89 | 93.82
TransMIL | 88.83 | 92.14 | 97.88 | 92.95
TransMIL w/ PseMix | 90.40 | 93.47 | 97.76 | 93.88
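As a quick sanity check, the Average AUC column matches the unweighted mean of the three per-dataset AUCs. The short snippet below recomputes it and prints the average gain from PseMix for each network. The values are copied from the table, and the helper name average_auc is ours, not from the repository.

```python
# AUC values (%) copied from the table above, in the order (BRCA, NSCLC, RCC)
results = {
    "ABMIL": (87.05, 92.23, 97.36),
    "ABMIL w/ PseMix": (89.49, 93.01, 98.02),
    "DSMIL": (87.73, 92.99, 97.65),
    "DSMIL w/ PseMix": (89.65, 93.92, 97.89),
    "TransMIL": (88.83, 92.14, 97.88),
    "TransMIL w/ PseMix": (90.40, 93.47, 97.76),
}

def average_auc(aucs):
    """Unweighted mean over the datasets, rounded to two decimals like the table."""
    return round(sum(aucs) / len(aucs), 2)

for base in ("ABMIL", "DSMIL", "TransMIL"):
    before = average_auc(results[base])
    after = average_auc(results[base + " w/ PseMix"])
    print(f"{base}: {before:.2f} -> {after:.2f} (+{after - before:.2f})")
# ABMIL: 92.21 -> 93.51 (+1.30)
# DSMIL: 92.79 -> 93.82 (+1.03)
# TransMIL: 92.95 -> 93.88 (+0.93)
```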

(2) obtain better generalization and robustness:

Training curves, showing the AUC performance on training and test data (exported from wandb), are given below.

(3) often obtain improvements when using the SOTA feature extractor CONCH:

🔥 New SOTA results on TCGA-BRCA, tested with the patch features extracted by CONCH:

Network | Prototype | AUC | ACC | F1-Score
ABMIL | - | 94.48 | 92.34 | 93.44
ABMIL w/ PseMix | PANTHER | 94.77 | 92.33 | 94.07
ABMIL w/ PseMix | ProDiv | 94.63 | 92.55 | 94.28
DSMIL | - | 92.16 | 89.51 | 90.91
DSMIL w/ PseMix | PANTHER | 93.46 | 91.60 | 93.46
DSMIL w/ PseMix | ProDiv | 93.90 | 91.49 | 93.63
TransMIL | - | 93.67 | 91.92 | 94.40
TransMIL w/ PseMix | PANTHER | 94.46 | 91.81 | 94.79
TransMIL w/ PseMix | ProDiv | 94.64 | 90.65 | 93.80

There are some fun observations from the table above:

PseMix Walkthrough

PseMix contains two key steps: i) pseudo-bag generation and ii) pseudo-bag mixup. Here are two alternative ways to quickly understand these two steps.

Option 1: Notebooks

Get started with PseMix:

Option 2: Overview and pseudo-codes

An overall description of the two key steps is as follows:

Step 1. Pseudo-bag Generation

Pseudo-bag generation contains two sub-steps:

Its implementation details can be found in the function generate_pseudo_bags, which is called directly in the next step, pseudo-bag mixup.
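For intuition, here is a hypothetical sketch of one way to divide a bag into pseudo-bags. It is not the repository's generate_pseudo_bags. The sketch assumes each instance (patch) already carries a cluster or phenotype label, for example from k-means on patch features (the clustering itself is not shown). It then spreads each cluster's instances as evenly as possible over N pseudo-bags, so every pseudo-bag sees a similar mix of phenotypes. The function name generate_pseudo_bags_sketch and its arguments are illustrative.

```python
import numpy as np

def generate_pseudo_bags_sketch(cluster_ids, n_pseudo_bags, rng=None):
    """Hypothetical sketch: split one WSI bag into `n_pseudo_bags` pseudo-bags.

    `cluster_ids[k]` is an assumed, precomputed cluster/phenotype label of instance k.
    Each cluster's instances are shuffled and split into nearly equal chunks, one
    chunk per pseudo-bag. Returns a list of index arrays that partition the bag.
    """
    rng = np.random.default_rng() if rng is None else rng
    cluster_ids = np.asarray(cluster_ids)
    buckets = [[] for _ in range(n_pseudo_bags)]
    for c in np.unique(cluster_ids):
        idx = rng.permutation(np.flatnonzero(cluster_ids == c))
        # rotate by a random offset so leftover instances do not always land in pseudo-bag 0
        offset = rng.integers(n_pseudo_bags)
        for j, chunk in enumerate(np.array_split(idx, n_pseudo_bags)):
            buckets[(j + offset) % n_pseudo_bags].extend(chunk.tolist())
    return [np.sort(np.asarray(b, dtype=int)) for b in buckets]

# Toy usage: 100 instances with 8-dim features and 4 (random) phenotype labels
rng = np.random.default_rng(0)
feats = rng.normal(size=(100, 8))
cluster_ids = rng.integers(0, 4, size=100)
pseudo_bags = [feats[idx] for idx in generate_pseudo_bags_sketch(cluster_ids, 4, rng)]
print([pb.shape[0] for pb in pseudo_bags])
```

Stratifying by cluster, rather than splitting the bag uniformly at random, is the design choice that keeps each pseudo-bag semantically similar to the whole bag under the stated assumption.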

Step 2. Pseudo-bag Mixup

Below is the pseudo-code of pseudo-bag mixup:

import numpy as np
import torch

# generate_pseudo_bags: function for dividing WSI bags into `N` pseudo-bags
# select_pseudo_bags: helper that randomly selects the given number of pseudo-bags and concatenates their instances
# loader, minibatch_training: your data loader and training step (placeholders, not shown)
# ALPHA: the hyper-parameter of the Beta distribution
# N: the number of pseudo-bags in each WSI bag
# PROB_MIXUP: probability that a sample is mixed, i.e., the expected proportion of mixed bags
# y: one-hot (or soft) labels, so that they can be mixed linearly
for (X, y) in loader: # load a minibatch
    n_batch = X.shape[0] # with `n_batch` WSI bags (samples)

    # 1. divide each bag into `N` pseudo-bags
    X = generate_pseudo_bags(X)

    new_idxs = torch.randperm(n_batch)
    # draw a mixing scale from the Beta distribution
    lam = np.random.beta(ALPHA, ALPHA)
    lam = min(lam, 1.0 - 1e-5) # keep lam < 1 so that `lam_discrete` below never exceeds N
    lam_discrete = int(lam * (N + 1)) # discretize into an integer in {0, 1, ..., N}

    # 2. pseudo-bag-level Mixup generates samples (new_X, new_y)
    new_X, new_y = [], []
    for i in range(n_batch):
        # randomly select pseudo-bags according to `lam_discrete`
        masked_bag_A = select_pseudo_bags(X[i], lam_discrete) # select `lam_discrete` pseudo-bags
        masked_bag_B = select_pseudo_bags(X[new_idxs[i]], N - lam_discrete) # select `N - lam_discrete` pseudo-bags

        # random-mixing mechanism for two purposes: more data diversity and efficient learning on mixed samples
        if np.random.rand() <= PROB_MIXUP:
            mixed_bag = torch.cat([masked_bag_A, masked_bag_B], dim=0) # instance-axis concat
            new_X.append(mixed_bag)
            mix_ratio = lam_discrete / N # fraction of pseudo-bags taken from bag i
        else:
            masked_bag = masked_bag_A
            new_X.append(masked_bag)
            mix_ratio = 1.0 # no mixing: the sample keeps the label of bag i

        # target-level mixing
        new_y.append(mix_ratio * y[i] + (1 - mix_ratio) * y[new_idxs[i]])

    # 3. minibatch training
    minibatch_training(new_X, new_y)
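The loop above can be made concrete for a single pair of bags with the self-contained NumPy sketch below. The helpers select_pseudo_bags and psemix_pair are illustrative versions written for this example (assumptions, not the repository's code). Each bag is given as a list of N pseudo-bags (arrays of shape [n_instances, dim]), and labels are one-hot vectors, e.g. np.eye(n_classes)[label]. The defaults alpha=1.0 and prob_mixup=0.5 are arbitrary placeholders, not the paper's settings.

```python
import numpy as np

def select_pseudo_bags(pseudo_bags, k, rng):
    """Hypothetical helper: randomly pick `k` of the pseudo-bags and concatenate their instances."""
    keep = sorted(rng.choice(len(pseudo_bags), size=k, replace=False))
    if not keep:  # k == 0: return an empty bag with the right feature dimension
        return np.empty((0, pseudo_bags[0].shape[1]), dtype=pseudo_bags[0].dtype)
    return np.concatenate([pseudo_bags[i] for i in keep], axis=0)

def psemix_pair(bags_a, y_a, bags_b, y_b, alpha=1.0, prob_mixup=0.5, rng=None):
    """Sketch of pseudo-bag mixup for one pair of WSI bags, each given as a list of N pseudo-bags."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(bags_a)
    lam = min(rng.beta(alpha, alpha), 1.0 - 1e-5)  # keep lam < 1 so lam_discrete <= N
    lam_discrete = int(lam * (n + 1))              # integer in {0, 1, ..., N}
    part_a = select_pseudo_bags(bags_a, lam_discrete, rng)
    if rng.random() <= prob_mixup:
        part_b = select_pseudo_bags(bags_b, n - lam_discrete, rng)
        new_x = np.concatenate([part_a, part_b], axis=0)  # instance-axis concat
        ratio = lam_discrete / n                           # fraction of pseudo-bags taken from bag A
    else:
        new_x, ratio = part_a, 1.0                         # no mixing: keep bag A's label
    return new_x, ratio * y_a + (1 - ratio) * y_b

# Toy usage: two bags of 4 pseudo-bags x 5 instances x 3 features, two classes
rng = np.random.default_rng(0)
bags_a = [rng.normal(size=(5, 3)) for _ in range(4)]
bags_b = [rng.normal(size=(5, 3)) for _ in range(4)]
x, y = psemix_pair(bags_a, np.eye(2)[0], bags_b, np.eye(2)[1], rng=rng)
print(x.shape, y)
```

The label weight is lam_discrete / N rather than the continuous lam, so the mixed target matches the actual fraction of pseudo-bags drawn from each bag. Like the pseudo-code, this sketch does not guard against an empty bag when lam_discrete is 0 and mixing is skipped.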

More details can be found at

WSI Preprocessing

The WSI preprocessing procedure is described in Pipeline-Processing-TCGA-Slides-for-MIL. Please refer to it for a detailed tutorial.

👩‍💻 Running the Code

Use the following command to load the running configuration from a YAML file and train the model:

python3 main.py --config config/cfg_clf_mix.yml --handler clf --multi_run

The configuration options that need attention are as follows:

Other configuration options are explained in config/cfg_clf_mix.yml and can be left unchanged.

🔥 Related Resources

Here we list related works that involve pseudo-bags or use them to train deep MIL networks.

Model | Subfield | Paper | Code | Base
BDOCOX (TMI'21) | WSI Survival Analysis | Weakly supervised deep ordinal cox model for survival prediction from wholeslide pathological images | - | K-means-based pseudo-bag division
DTFD-MIL (CVPR'22) | WSI Classification | Dtfd-mil: Double-tier feature distillation multiple instance learning for histopathology whole slide image classification | GitHub | Random pseudo-bag division
ProDiv (CMPB'24) | WSI Classification | ProDiv: Prototype-driven Consistent Pseudo-bag Division for Whole-slide Image Classification | GitHub | Prototype-based consistent pseudo-bag division
PseMix (TMI'24) | WSI Classification | Pseudo-Bag Mixup Augmentation for Multiple Instance Learning-Based Whole Slide Image Classification | GitHub | Pseudo-bag Mixup
ICMIL (TMI'24) | WSI Classification | Rethinking Multiple Instance Learning for Whole Slide Image Classification: A Bag-Level Classifier is a Good Instance-Level Teacher | GitHub | Utilizing pseudo-bags in training
PMIL (TMI'24) | WSI Classification | Shapley Values-enabled Progressive Pseudo Bag Augmentation for Whole Slide Image Classification | GitHub | Progressive pseudo-bag augmentation
SWS-MIL (arXiv'24) | WSI Classification | MergeUp-augmented Semi-Weakly Supervised Learning for WSI Classification | - | Adaptive pseudo bag augmentation

NOTE: please open a new PR if you want to add your work to this resource list.

📝 Citation

If you find this work helpful for your research, please consider citing our paper:

@article{liu10385148,
  author={Liu, Pei and Ji, Luping and Zhang, Xinyu and Ye, Feng},
  journal={IEEE Transactions on Medical Imaging}, 
  title={Pseudo-Bag Mixup Augmentation for Multiple Instance Learning-Based Whole Slide Image Classification}, 
  year={2024},
  volume={43},
  number={5},
  pages={1841-1852},
  doi={10.1109/TMI.2024.3351213}
}

or P. Liu, L. Ji, X. Zhang and F. Ye, "Pseudo-Bag Mixup Augmentation for Multiple Instance Learning-Based Whole Slide Image Classification," in IEEE Transactions on Medical Imaging, vol. 43, no. 5, pp. 1841-1852, May 2024, doi: 10.1109/TMI.2024.3351213.