matteo-bastico / MI-Seg

Independent Multi-Modal Segmentation
Apache License 2.0



MI-Seg

MI-Seg is a framework based on the MONAI library for cross-modality segmentation of clinical images using Conditional Models and Interleaved Training.
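One way to picture interleaved training is a loop that alternates mini-batches from the two modalities, so that a single conditional model sees CT and MR in turn and receives the modality tag as its conditioning signal. The framework-free sketch below illustrates that reading under this assumption; the helper name `interleave_batches` and the strict round-robin order are hypothetical and are not taken from the MI-Seg code.

```python
from itertools import zip_longest
from typing import Any, Iterable, Iterator, Tuple

_MISSING = object()  # marks an exhausted modality (safe even if a batch is None)


def interleave_batches(ct: Iterable[Any], mr: Iterable[Any]) -> Iterator[Tuple[str, Any]]:
    """Yield (modality, batch) pairs, alternating CT and MR batches.

    Hypothetical helper: once one modality runs out, the remaining
    batches of the other modality are yielded in their original order.
    """
    for ct_batch, mr_batch in zip_longest(ct, mr, fillvalue=_MISSING):
        if ct_batch is not _MISSING:
            yield "ct", ct_batch
        if mr_batch is not _MISSING:
            yield "mr", mr_batch


if __name__ == "__main__":
    for modality, batch in interleave_batches(["c1", "c2", "c3"], ["m1", "m2"]):
        print(modality, batch)
    # prints: ct c1, mr m1, ct c2, mr m2, ct c3 (one pair per line)
```

In an actual training step, the modality string would select the conditioning of the normalization layers, while the batch itself goes through the shared network weights.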

Table of Contents
  1. About The Project
  2. Getting Started
  3. Usage
  4. Roadmap
  5. Contributing
  6. License
  7. Contact
  8. Acknowledgments

Citation

Our paper has been accepted at ICCVW 2023 and is available here and on arXiv. Please cite our work as:

  @InProceedings{Bastico_2023_ICCV,
    author    = {Bastico, Matteo and Ryckelynck, David and Cort\'e, Laurent and Tillier, Yannick and Decenci\`ere, Etienne},
    title     = {A Simple and Robust Framework for Cross-Modality Medical Image Segmentation Applied to Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
    month     = {October},
    year      = {2023},
    pages     = {4128-4138}
  }


Built With

Our released implementation is tested on:


Getting Started

Prerequisites


Usage

Dataset

The dataset used in our experiments can be downloaded here upon access request. Download and unzip it into the /dataset/MM-WHS folder.

[Optional] Convert the labels and perform N4 bias field correction of the MRIs using the provided notebook load_data.ipynb.

You should end up with a data structure similar to the following (sub-folders are not shown):

  MM-WHS
  ├── ct_train  # CT training folder
  │   ├── ct_train_1001_image.nii.gz # Image
  │   ├── ct_train_1001_label.nii.gz # Label
  │   ...
  ├── ct_test
  ├── mr_train
  ├── mr_test
  ...
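To check that the data landed in the expected layout, the image and label files of a split folder can be paired by name. The sketch below assumes the `<split>_<id>_image.nii.gz` / `<split>_<id>_label.nii.gz` naming shown in the tree above and returns a list of `{"image": ..., "label": ...}` dictionaries, the key convention used by MONAI dictionary transforms such as LoadImaged. The function `collect_pairs` is hypothetical, only looks at the top level of each split folder, and skips images without a matching label (for example unlabeled test cases).

```python
from pathlib import Path
from typing import Dict, List


def collect_pairs(root: str, split: str) -> List[Dict[str, str]]:
    """Pair <split>_<id>_image.nii.gz with <split>_<id>_label.nii.gz.

    Hypothetical helper; assumes the naming shown in the tree above.
    Images without a matching label file are skipped.
    """
    folder = Path(root) / split
    pairs = []
    for image in sorted(folder.glob("*_image.nii.gz")):
        label = image.with_name(image.name.replace("_image.nii.gz", "_label.nii.gz"))
        if label.exists():
            pairs.append({"image": str(image), "label": str(label)})
    return pairs


if __name__ == "__main__":
    # Example (path assumed): list the CT training pairs.
    for pair in collect_pairs("/dataset/MM-WHS", "ct_train"):
        print(pair["image"], "->", pair["label"])
```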

The splits we used for cross-validation are provided in CT_fold1.json and CT_fold2.json.

Training

To train a model, use the provided train.py script. Single training runs are based on PyTorch Lightning, and all Trainer arguments can be passed to the script (see here). We additionally provide model-, data- and logger-specific arguments. For the full list of available arguments, run python train.py --help.

An example of C-Swin-UNETR training on a single GPU:

python train.py --model_name=swin_unetr --out_channels=6 --feature_size=48 --num_heads=3 --accelerator=gpu --devices=1 --max_epochs=2500 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --lr=1e-4 --batch_size=1 --patches_training_sample=1

The available models are unet, unetr, swin_unetr and pre_swin_unetr (for pre_swin_unetr, the pre-trained MONAI model must be provided with --pre_swin).
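The instance_cond value passed to --encoder_norm_name and --vit_norm_name presumably makes the normalization layers depend on the input modality. A common way to do this is conditional instance normalization: mean and variance are computed per sample and per channel, while the learned scale (gamma) and shift (beta) are selected by the modality label. The pure-Python sketch below illustrates that idea under this assumption; it is not the MI-Seg layer, and the parameter layout `params[modality] = (gammas, betas)` is hypothetical.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical layout: one (gammas, betas) pair per modality, one value per channel.
# In a real network these are learnable parameters.
Params = Dict[str, Tuple[List[float], List[float]]]


def conditional_instance_norm(
    x: List[List[float]], modality: str, params: Params, eps: float = 1e-5
) -> List[List[float]]:
    """Normalize each channel of a single sample with its own statistics,
    then apply the gamma/beta pair chosen by `modality`."""
    gammas, betas = params[modality]
    out = []
    for c, values in enumerate(x):
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([gammas[c] * (v - mean) * inv_std + betas[c] for v in values])
    return out


if __name__ == "__main__":
    params = {"ct": ([1.0], [0.0]), "mr": ([2.0], [5.0])}
    print(conditional_instance_norm([[1.0, 3.0]], "ct", params))  # approximately [[-1.0, 1.0]]
    print(conditional_instance_norm([[1.0, 3.0]], "mr", params))  # approximately [[3.0, 7.0]]
```

The same input is normalized identically for both modalities; only the final affine step differs, which is what lets one set of convolution or attention weights serve CT and MR.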

Furthermore, we use WandB to log the experiments, and its settings can be passed as arguments. In the previous example, wandb runs in online mode, so you need to provide your login and API key. To change the wandb mode, set wandb_mode=offline.

Note: when training Swin-UNETR-based models with gradient checkpointing (--use_checkpoint) to save memory, AMP should be disabled (--no_amp).

Testing

Our pre-trained models can be downloaded here and tested with the test.py script. Provide the path to the model weights with --checkpoint (note that the weights must be stored under the state_dict key).

Example:

python test.py --out_channels=6 --model_name=swin_unetr --num_workers=2 --feature_size=48  --num_heads=3 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --checkpoint=experiments/<path>
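PyTorch Lightning checkpoints already store the weights under a state_dict key, but each parameter name is prefixed with the attribute name of the network inside the LightningModule. If you need to load such weights into a bare network, a small helper can strip that prefix. The sketch below is hypothetical: the default prefix `model.` is an assumption, so inspect the keys of your own checkpoint first.

```python
from typing import Any, Dict


def extract_state_dict(checkpoint: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Return the weights stored under 'state_dict', removing `prefix`
    from every key that starts with it (other keys are kept unchanged).

    The default prefix is an assumption; pass prefix="" to keep keys as-is.
    """
    if "state_dict" not in checkpoint:
        raise KeyError("checkpoint has no 'state_dict' key")
    return {
        (key[len(prefix):] if prefix and key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }


if __name__ == "__main__":
    ckpt = {"state_dict": {"model.conv.weight": 1, "model.conv.bias": 2}, "epoch": 5}
    print(extract_state_dict(ckpt))  # {'conv.weight': 1, 'conv.bias': 2}
```

With PyTorch installed, the result can be passed to your instantiated network, e.g. net.load_state_dict(extract_state_dict(torch.load(path, map_location="cpu"))).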

Hyper-parameter Optimization

Hyper-parameter optimization is based on Optuna. For the moment, the script supports automatic setup of distributed tuning ONLY in Slurm environments, so it must be adapted by the user to run in other multi-GPU environments.

The hyper-parameter search space is set automatically for each model as described in our paper, and tuning can be started as shown below. The script runs 10 trials with the TPE sampler and the ASHA pruner, and saves the study in the MI-Seg.log file (with Slurm) or in MI-Seg.sqlite (without Slurm).

python -u tune.py --num_workers=2 --out_channels=6 --no_include_background --criterion=generalized_dice_focal --scheduler=warmup_cosine --model_name=swin_unetr --n_trials=10 --study_name=c-swin-unetr --max_epochs=2500 --check_val_every_n_epoch=50 --batch_size=1 --patches_training_sample=4 --iters_to_accumulate=4 --cycles=0.5 --storage_name=MI-Seg --min_lr=1e-5 --max_lr=1e-3 --vit_norm_name=instance_cond --encoder_norm_name=instance_cond  --port=23456
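ASHA (asynchronous successive halving) takes pruning decisions at fixed checkpoints called rungs: a trial may continue only if its intermediate value is among the top 1/η of all values reported so far at that rung, where η is the reduction factor. The plain-Python sketch below reproduces that rule for a metric to be maximized, such as a Dice score. The first rung at 50 epochs (mirroring --check_val_every_n_epoch=50) and η = 4 (Optuna's default reduction_factor) are assumptions for illustration; Optuna's SuccessiveHalvingPruner handles further details, such as a minimum early-stopping rate.

```python
from typing import List


def rung_steps(min_resource: int, eta: int, max_resource: int) -> List[int]:
    """Epochs at which promotion decisions are taken: min_resource * eta**k."""
    if min_resource <= 0 or eta < 2:
        raise ValueError("need min_resource > 0 and eta >= 2")
    steps, step = [], min_resource
    while step <= max_resource:
        steps.append(step)
        step *= eta
    return steps


def is_promotable(value: float, rung_values: List[float], eta: int) -> bool:
    """True if `value` lies within the top 1/eta of `rung_values` (higher is better).

    `rung_values` holds every value recorded at this rung so far, including
    `value`; at least one trial is always promotable.
    """
    keep = max(len(rung_values) // eta, 1)  # how many trials may continue
    threshold = sorted(rung_values, reverse=True)[keep - 1]
    return value >= threshold


if __name__ == "__main__":
    print(rung_steps(50, 4, 2500))  # [50, 200, 800]
    dice_at_rung = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    print(is_promotable(0.7, dice_at_rung, 4))  # True
    print(is_promotable(0.6, dice_at_rung, 4))  # False
```

Because the decision only uses values already reported at the rung, early trials are judged against fewer competitors; this is what makes the scheme asynchronous and well suited to parallel workers.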

The script can be run multiple times with the same --storage_name to continue a previous tuning run.

To open the dashboard of studies stored in log files rather than in an RDB, we provide the utils/run_server.py --path=<storage> script. The dashboard of the tuning presented in our paper is available at experiments/optuna/MI-Seg.log and can be opened with:

python utils/run_server.py --path=experiments/optuna/MI-Seg.log


Pre-Trained Models

The best pre-trained model weights for Conditional UNet and Swin-UNETR resulting from our hyper-parameters optimization can be downloaded here.

For instance, to produce segmentations of the test dataset with the provided weights, run the following for Conditional UNet:

python predict_whs.py --model=unet_vanilla --encoder_norm_name=instance_cond --feature_size 16 64 128 256 512 --num_res_units=3 --strides 1 2 2 2 1 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result

or for Conditional Swin-UNETR:

python -u predict_whs.py --model=swin_unetr --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --feature_size=36 --num_heads=4 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result
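The two commands above use different architecture flags. For the Conditional UNet, --feature_size lists the channels of each level and --strides the downsampling factor of each level; assuming each stride applies to its own level, a patch side should be divisible by the product of the strides (here 1*2*2*2*1 = 8) to avoid size mismatches between encoder and decoder feature maps. For Swin-UNETR, MONAI's SwinUNETR requires feature_size to be divisible by 12, which holds for both 36 and 48. The helpers below check both constraints; their names and the patch side of 96 in the example are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def unet_level_sizes(
    patch_side: int, channels: Sequence[int], strides: Sequence[int]
) -> List[Tuple[int, int]]:
    """Return (spatial side, channels) per level, assuming each stride
    downsamples the input of its own level. Raises if sizes do not divide evenly."""
    if len(channels) != len(strides):
        raise ValueError("expected one stride per feature level")
    total = math.prod(strides)
    if patch_side % total != 0:
        raise ValueError(f"patch side {patch_side} is not divisible by {total}")
    levels, side = [], patch_side
    for ch, stride in zip(channels, strides):
        side //= stride
        levels.append((side, ch))
    return levels


def check_swin_feature_size(feature_size: int) -> None:
    """Mirror MONAI's SwinUNETR check: feature_size must be a multiple of 12."""
    if feature_size % 12 != 0:
        raise ValueError("feature_size should be divisible by 12")


if __name__ == "__main__":
    print(unet_level_sizes(96, [16, 64, 128, 256, 512], [1, 2, 2, 2, 1]))
    # [(96, 16), (48, 64), (24, 128), (12, 256), (12, 512)]
    check_swin_feature_size(36)  # passes silently
```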

Roadmap