cherise215 / Cooperative_Training_and_Latent_Space_Data_Augmentation

[MICCAI 2021 Oral] Cooperative Training and Latent Space Data Augmentation for Robust Medical Image Segmentation

Domain Generalized, Robust Medical Image Segmentation

This repo contains the PyTorch implementation of "Cooperative Training and Latent Space Data Augmentation for Robust Medical Image Segmentation" (MICCAI 2021 Oral). [Video] [Paper] [Poster].

Introduction

We propose a cooperative training framework for single domain generalization, consisting of a dual-thinking framework and a latent space data augmentation method. Unlike existing domain generalization methods, our method does not require multi-domain datasets to learn domain-invariant robust features. Our network can generate its own challenging images and segmentations to simulate unseen domain shifts. These hard examples are then used to reinforce the training of the dual-thinking framework, improving cross-domain performance.

Dual thinking framework

Inspired by the two-system model of human thinking from behavioural science, we design a dual-thinking framework with a fast-thinking network (FTN) for intuitive judgement (image understanding and segmentation) and a slow-thinking network (STN) for shape correction and refinement.
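The data flow of the dual-thinking framework can be sketched as follows. This is a minimal illustration, not the repo's architecture: `fast_thinking_net` and `slow_thinking_net` are hypothetical stand-ins (a per-pixel linear classifier and a box filter acting only on the predicted class maps), chosen so the example runs with NumPy alone. The point is the composition: the FTN looks at the image, while the STN sees only the segmentation and corrects its shape.

```python
import numpy as np

def softmax(logits, axis=0):
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def fast_thinking_net(image, weights):
    """Hypothetical FTN stand-in: a per-pixel linear classifier.

    image: (H, W) array; weights: (K, 2) array of per-class [scale, bias].
    Returns class probabilities of shape (K, H, W).
    """
    logits = weights[:, 0, None, None] * image[None] + weights[:, 1, None, None]
    return softmax(logits, axis=0)

def slow_thinking_net(probs):
    """Hypothetical STN stand-in: refines a segmentation from its shape alone.

    A 3x3 box filter smooths each class map as a crude shape prior; the real
    STN is a learned network that takes the segmentation, not the image.
    """
    padded = np.pad(probs, ((0, 0), (1, 1), (1, 1)), mode="edge")
    h, w = probs.shape[1:]
    smoothed = sum(padded[:, dy:dy + h, dx:dx + w]
                   for dy in range(3) for dx in range(3)) / 9.0
    # Renormalize so the class probabilities at each pixel sum to one.
    return smoothed / smoothed.sum(axis=0, keepdims=True)

def dual_thinking_predict(image, weights):
    coarse = fast_thinking_net(image, weights)  # fast: intuitive segmentation
    refined = slow_thinking_net(coarse)         # slow: shape correction
    return coarse, refined

# Usage on a random 8x8 "image" with two classes (background, foreground).
rng = np.random.default_rng(0)
image = rng.random((8, 8))
weights = np.array([[-4.0, 2.0], [4.0, -2.0]])
coarse, refined = dual_thinking_predict(image, weights)
label_map = refined.argmax(axis=0)
```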

Please see medseg/models/advanced_triplet_recon_segmentation_model.py for the detailed network structures. Researchers are welcome to adjust the internal structures of the encoders and decoders to improve performance on their own datasets.

Latent space data augmentation

We apply masking to the latent image code and the latent shape code to obtain corrupted images and corrupted segmentations, which are then used to train both the FTN and the STN.
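Below is a minimal sketch of the corruption step, using content-agnostic, dropout-style random masking as the example scheme. The helper names `random_mask` and `corrupt_codes` are hypothetical, and the codes are assumed to be `(C, H, W)` arrays. No dropout-style rescaling is applied here; whether to rescale is a design choice. In the full method, the corrupted codes would be passed to the image and segmentation decoders to synthesize the corrupted image and segmentation.

```python
import numpy as np

def random_mask(shape, p, rng):
    """Binary mask where each element is zeroed independently with probability p."""
    return (rng.random(shape) >= p).astype(np.float32)

def corrupt_codes(image_code, shape_code, p, rng):
    """Hypothetical helper: mask the image code and shape code independently.

    image_code, shape_code: (C, H, W) latent codes produced by the encoder.
    Returns the masked codes, ready to be decoded into a corrupted image
    and a corrupted segmentation.
    """
    image_code_masked = image_code * random_mask(image_code.shape, p, rng)
    shape_code_masked = shape_code * random_mask(shape_code.shape, p, rng)
    return image_code_masked, shape_code_masked

rng = np.random.default_rng(0)
z_image = rng.standard_normal((4, 3, 3))
z_shape = rng.standard_normal((2, 3, 3))
z_image_hard, z_shape_hard = corrupt_codes(z_image, z_shape, p=0.3, rng=rng)
```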

In our paper, we present three ways to perform latent code masking: random masking and two targeted masking schemes.

Unlike random masking, which is agnostic to image content, the targeted masking schemes use the gradients of task-specific losses with respect to the latent code to identify salient features in the latent space to mask. In this way, we can generate challenging samples that benefit the learning of the downstream tasks.
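The gradient-guided idea can be sketched in NumPy as follows. Since NumPy has no autograd, the gradient of the task loss with respect to the latent code is passed in; in the toy usage it is computed analytically for a quadratic loss. The two variants shown (masking whole channels or whole spatial positions) and the scoring rule (mean gradient, with the top `p` fraction masked) are assumptions made for illustration; please refer to the paper and to `hard_example_generation` in the repo for the exact schemes.

```python
import numpy as np

def targeted_channel_mask(grad, p):
    """Mask the channels whose mean gradient falls in the top p fraction.

    grad: (C, H, W) gradient of a task loss w.r.t. the latent code.
    Returns a (C, 1, 1) binary mask that broadcasts over H and W.
    """
    score = grad.mean(axis=(1, 2))
    threshold = np.quantile(score, 1.0 - p)
    return (score <= threshold).astype(grad.dtype)[:, None, None]

def targeted_spatial_mask(grad, p):
    """Mask the spatial positions whose channel-averaged gradient is in the top p fraction.

    Returns a (1, H, W) binary mask that broadcasts over channels.
    """
    score = grad.mean(axis=0)
    threshold = np.quantile(score, 1.0 - p)
    return (score <= threshold).astype(grad.dtype)[None]

# Toy loss L = 0.5 * ||z - target||^2, so dL/dz = z - target.
z = np.ones((4, 2, 2))
target = 1.0 - np.arange(4.0)[:, None, None]  # makes channel c's gradient equal c
grad = z - target
z_hard = z * targeted_channel_mask(grad, p=0.25)  # zeroes the most salient channel (c = 3)
```

Using a quantile threshold keeps the masked fraction roughly equal to `p` regardless of the gradient scale, which is convenient when `p` is sampled at random.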

We also provide a Jupyter notebook to visualize hard example generation: visualization/vis_hard_example.ipynb.

At training time, we randomly select one of the masking schemes, with a randomly sampled probability/threshold p, to generate hard samples. Please see the hard_example_generation function in medseg/models/advanced_triplet_recon_segmentation_model.py.
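The per-iteration sampling can be sketched as below. This is a hypothetical, self-contained analogue of the selection step, not the repo's `hard_example_generation`: the scheme names, the uniform choice between them, and the default `p_range=(0.0, 0.5)` are all assumptions for illustration. The targeted branches reuse the quantile-threshold idea of ranking latent features by their mean gradient.

```python
import numpy as np

SCHEMES = ["random", "channel", "spatial"]  # hypothetical scheme names

def hard_example_mask(code, grad, rng, p_range=(0.0, 0.5)):
    """Mask a (C, H, W) latent code with a randomly chosen scheme and random p.

    grad: gradient of a task loss w.r.t. `code` (same shape), used by the
    targeted schemes. Returns (masked_code, scheme_name, p).
    """
    p = rng.uniform(*p_range)
    scheme = str(rng.choice(SCHEMES))
    if scheme == "random":
        # Content-agnostic: drop each element with probability p.
        mask = (rng.random(code.shape) >= p).astype(code.dtype)
    elif scheme == "channel":
        # Targeted: drop channels with the largest mean gradient.
        score = grad.mean(axis=(1, 2))
        mask = (score <= np.quantile(score, 1.0 - p)).astype(code.dtype)[:, None, None]
    else:
        # Targeted: drop spatial positions with the largest channel-averaged gradient.
        score = grad.mean(axis=0)
        mask = (score <= np.quantile(score, 1.0 - p)).astype(code.dtype)[None]
    return code * mask, scheme, float(p)

rng = np.random.default_rng(1)
code = rng.standard_normal((4, 3, 3))
grad = rng.standard_normal((4, 3, 3))
hard_code, scheme, p = hard_example_mask(code, grad, rng)
```

Re-sampling both the scheme and `p` at every step exposes the networks to a range of corruption types and strengths rather than a single fixed perturbation.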

For more details please see our paper on arXiv.

Set Up

Data

Usage

Citation

If you find this useful for your work, please consider citing:

@INPROCEEDINGS{Chen_MICCAI_2021_Cooperative,
  title     = "Cooperative Training and Latent Space Data Augmentation for Robust Medical Image Segmentation",
  booktitle = "Medical Image Computing and Computer Assisted Intervention --
               {MICCAI} 2021",
  author    = {Chen Chen and
               Kerstin Hammernik and
               Cheng Ouyang and
               Chen Qin and
               Wenjia Bai and
               Daniel Rueckert},
  publisher = "Springer International Publishing",
  year      =  2021
}