clemsgrs / hipt

Re-implementation of HIPT
18 stars 6 forks source link

Hierarchical Image Pyramid Transformer

Re-implementation of original HIPT code.

empty empty

Requirements

Prerequisite

You need to have extracted square regions from each WSI you intend to train on.
To do so, you can take a look at HS2P, which segments tissue and extract relevant patches at a given pixel spacing.

Download HIPT pre-trained weights via the following commands:

Download commands ``` mkdir checkpoints cd checkpoints gdown 1Qm-_XrTMYhu9Hl-4FClaOMuroyWlOAxw gdown 1A2eHTT0dedHgdCvy6t3d9HwluF8p5yjz ```

Feature Extraction

1. [Optional] Configure wandb

If you want to benefit from wandb logging, you need to follow these simple steps:

2. Extract features

Your folder containing the extracted square regions should be structured as follow:

Folder structure ```bash region_dir/ ├── slide_1/ │ ├── slide_1.h5 │ └── imgs/ │ ├── region_1.fmt │ ├── region_2.fmt │ └── ... ├── slide_2/ ├── slide_3/ └── ... ```

Create a configuration file under config/feature_extraction/ taking inspiration from existing files.
A good starting point is to use the default configuration file config/default.yaml where parameters are documented.

Then run the following command to kick off feature extraction:

python3 extract_features.py --config-name <feature_extraction_config_filename>

This will produce one .pt file per slide and save it under output/<experiment_name>/<experiment_id>/slide.
If you set save_region_features to True in your config file, it will also produce one .pt file per region and save it under output/<experiment_name>/<experiment_id>/region. In the end, the output folder should look like:

hipt/
├── output/<experiment_name>/<experiment_id>/features/
│     ├── slide/
│     │    ├── slide_1.pt
│     │    ├── slide_2.pt
│     │    └── ...
│     └── region/
│          ├── slide_1_x1_y1.pt
│          ├── slide_1_x2_y2.pt
│          └── ...

If you encounter OOM errors while extracting region-level features, it could be that the model is too big to fit on your gpu. In that case, work with 2 gpus and enable model split across gpus setting split_across_gpus: True in your config file.

Distributed feature extraction across multiple gpus is supported via:

python3 -m torch.distributed.run --standalone --nproc_per_node=gpu extract_features.py --config-name <feature_extraction_config_filename>

HIPT Training

1. Prepare train.csv and tune.csv

For this pipeline you will need two csv files: train.csv and tune.csv.
The syntax is easy:

slide_id,label
TRAIN_1,1
TRAIN_2,0
...

If you want to run testing at the end, you can provide a test.csv file.

2. Train a single-fold model on extracted features

Once features have been extracted, create a configuration file under:

You can take inspiration from single.yaml files.
Dump in there the paths to your train.csv and tune.csv files.
If you want to run testing as well, add the path to your test.csv file. Otherwise, leave it blank (it'll skip testing).

Then, run the following command to kick off model training on a single fold:

Once again, if you encounter OOM errors here, it could be that the model is too big to fit on your gpu. In that case, work with 2 gpus and enable model split across gpus setting split_across_gpus: True in your config file.

3. Train a multi-fold model on extracted features

Your multiple folds should be structured as follow:

Folds structure ```bash fold_dir/ ├── fold_1/ │ ├── train.csv │ ├── tune.csv │ └── test.csv ├── fold_2/ └── ... ```

Create a configuration file under:

You can take inspiration from multi.yaml files.
Remember to indicate the root directory where your folds are located under data.fold_dir.

Then, run the following command to kick off model training on multiple folds:

Interpretability via Attention Heatmaps

Given HIPT consists in 3 Transformers stacked on one another, one can look at attention heatmaps for each Transformer independantly. Additionally, one can also blend these attention heatmaps together and generate factorized attention heatmaps.

To generate attention heatmaps, create a configuration file under config/heatmaps.
You can take inspiration from default.yaml, where all entries should be documented.

Then, run the following command to kick off attention heatmap generation:

python3 attention_visualization.py --config-name <heatmaps_config_filename>

Depending on which part of the code is not commented out, it will produce many .png files per slide and save them under output/<experiment_name>/<experiment_id>:

Heatmap output structure ```bash hipt/ ├── output/// │ ├── slide_1/ │ │ ├── patch.png │ │ ├── region.png │ │ ├── patch// │ │ │ ├── head_0/ │ │ │ │ ├── x1_y1.png │ │ │ │ ├── x2_y2.png │ │ │ │ └── ... │ │ │ ├── head_1/ │ │ │ └── ... │ │ ├── region// │ │ │ ├── head_0/ │ │ │ │ ├── x1_y1.png │ │ │ │ ├── x2_y2.png │ │ │ │ └── ... │ │ │ ├── head_1/ │ │ │ └── ... │ │ ├── hierarchical__/ │ │ │ ├── rhead_0_phead_0/ │ │ │ │ ├── x1_y1.png │ │ │ │ ├── x2_y2.png │ │ │ │ └── ... │ │ │ ├── rhead_0_phead_1/ │ │ │ └── ... │ │ └── slide/ │ │ ├── patch_head_0.png │ │ ├── region_head_0.png │ │ ├── rhead_0_phead_0.png │ │ └── ... │ ├── slide_2/ │ │ └── ... │ ├── slide_3/ │ └── ... ```

where, for each slide (and for each attention head):

python3 -m torch.distributed.run --standalone --nproc_per_node=gpu attention_visualization.py --config-name <heatmaps_config_filename>

Hierarchical Pretraining

Example Directory ```bash PRETRAINING_DIR/ └──patch_256_pretraining/ └──imgs/ ├── patch_1.png ├── patch_2.png └── ... └──region_4096_pretraining/ ├── slide_1_region_1.pt ├── slide_1_region_2.pt └── ... ```

Where:

NB: you should be able to user differently sized regions (e.g. [1024 × 1024]) seamlessly.

Create two configuration files under config/pre-training/ (one for each pre-training stage).
You can take inspiration from existing files.

The following commands are used for single-gpu pretraining:

python3 pretrain/dino_patch.py --config-name <dino_patch_config>
python3 pretrain/dino_region.py --config-name <dino_region_config>

Distributed pretraining across multiple gpus is supported via:

python3 -m torch.distributed.run --nproc_per_node=gpu pretrain/dino_patch.py --config-name <dino_patch_config>
python3 -m torch.distributed.run --nproc_per_node=gpu pretrain/dino_region.py --config-name <dino_region_config>

Early Stopping during DINO pretraining

When pretraining the patch-level Transformer, we can monitor student & teacher performances via a downstream patch-level classification task. This can later be used as an early stopping mechanism to prevent model overtraining.

For more details about this, check DINO paper at https://arxiv.org/abs/2104.14294.

To enable early stopping, make sure to check the parameters under early_stopping in your configuration file:

parameter type description
tune_every int run downstream evaluation every {tune_every} epochs (leave blank if you want disable downstream evaluation)
tracking str which metric should be tracked for early stopping
min_max str whether the previous metric should be maximised ('max') or minimized ('min')
patience int early stopping patience (in number of epochs)
min_epoch int minimum number of epochs to complete before early stopping can be triggered
save_every int save checkpoint every {save_every} epochs
downstream.train_csv str path to the downstream train .csv file (it should have the following columns: tile_path : path to patch image, label : patch label)
downstream.test_csv str path to the downstream test .csv file (same requirements)
downstream.label_name str column name holding the labels in the previous .csv files
knn.k int number of neighbours to retrieve
knn.temperature int temperature value used when computing kNN probabilities
knn.save_features bool whether to save kNN features to disk