LiuTingWed / CriDiff

The official implement of paper CriDiff: Criss-cross Injection Diffusion Framework via Generative Pre-train for Prostate Segmentation.
7 stars 0 forks source link

CriDiff

The official implement of MICCAI 2024 paper CriDiff: Criss-cross Injection Diffusion Framework via Generative Pre-train for Prostate Segmentation. Structure Figure

Environment Installation

conda create -n CriDiff python=3.8 -y
conda activate CriDiff
git clone https://github.com/LiuTingWed/CriDiff.git
cd CriDiff
pip install -r requirements.txt

Datasets Preparation

Download Datasets

4 datasets need download (NCI-ISBI, ProstateX, Promise12, CCH-TRUSPS) from: \ Google Driver | Baidu Driver (6666) \ I'm not sure about the copyright status of these datasets. If you are the owner of these datasets, please submit an issue to let me know so that I can remove them accordingly.

Check data branch like this:

Data_branch \ The body and detail are generated by extract_boundary/generate_body_detail.py. \ Please check this .py for more details.

Download Pre-train Weight

Google Driver (PVT_b2)

Training & Inference & Evaluation

Generative pretrain

This stage relies on accelerate, please install it and set it up. \ python generative_pretrain/train_generator_accelerate.py --dataset_root xxx/DATASET_NAME/images/train

Training

Before training, please check --dataset_root, --cp_condition_net, --cp_stage1, --checkpoint_save_dir in train.py \ python -m torch.distributed.launch --nproc_per_node=2 train.py

Why can't the model perform training and validation simultaneously?

The output of diffusion models is related to the randomly sampled noise: different noise leads to different outputs. I have not addressed the issue of fluctuating model performance between the training and validation stages, for detailed descriptions please refer to this link. Therefore, I would recommend saving all checkpoints, and then using two separate GPUs for validation to ensure that others can also achieve consistent performance. Well, I hope someone smarter than me tell me why :-).

Inference

After training, in path --checkpoint_save_dir/job_name will have many .pth file. \ Check --loadDir, --loadDer_cp and --dataset_root in infer_allCp_xxxx.py and run it.

Evaluation

The prediction of CriDiff is this link, run eval_dice_iou_hd95_asd/eval.py to eval it.

Thanks

This repository refer to med-seg-diff-pytorch and denoising-diffusion-pytorch. Some very concise diffusion frameworks are helpful to me.

Citation

@article{liu2024cridiff,
  title={CriDiff: Criss-cross Injection Diffusion Framework via Generative Pre-train for Prostate Segmentation},
  author={Liu, Tingwei and Zhang, Miao and Liu, Leiye and Zhong, Jialong and Wang, Shuyao and Piao, Yongri and Lu, Huchuan},
  journal={arXiv preprint arXiv:2406.14186},
  year={2024}
}

Any questions please contact with tingweiliu@mail.dlut.edu.cn