MIT License
# BioDiffusion: A Versatile Diffusion Model for Biomedical Signal Synthesis
## Description

This repository is the primary codebase for a diffusion model designed to generate synthetic biomedical signals. The model is the key component of the research paper "BioDiffusion: A Versatile Diffusion Model for Biomedical Signal Synthesis" ([accessible here](https://arxiv.org/abs/2401.10282)).

## Setup

Before running the code, ensure that you have the following prerequisites installed:

- Python 3.x
- PyTorch
- NVIDIA CUDA toolkit and cuDNN (for GPU acceleration)

```bash
pip install torch torchvision
conda install cudatoolkit
```

### Conda Virtual Environment

Create the Conda virtual environment using the [environment file](environment.yml):

```bash
conda env create -f environment.yml
# dynamically set python path for the environment
conda activate BioDiffusion
conda env config vars set PYTHONPATH=$(pwd):$(pwd)/src
```

## Training

To use the model, you first need to train it:

### Unconditional diffusion model for 3 channel data:

```bash
python ddpm.py
```

### Conditional diffusion model for 3 channel data:

```bash
python ddpm_conditional.py
```

### Classifier-free diffusion model for 1 channel data:

```bash
cd signal/
python load_dataset.py #KaggleKey #KaggleName
python ddpm1d_cls_free.py
```

### Signal conditional diffusion model for 1 channel data:

```bash
cd signal/
python load_dataset.py #KaggleKey #KaggleName
python ddpm1d_sign_cond.py
```

## Sampling

After [training](#training), you can sample from the trained model using the following steps. The snippets assume that the model and diffusion classes (`UNet`, `Diffusion`, `Unet1D`, `Unet1D_cls_free`, `GaussianDiffusion1D`, `GaussianDiffusion1D_cls_free`) have already been imported from this repository.

### Unconditional Diffusion Model for 3 Channel Data:

```python
import torch

# Set device and load the pre-trained UNet model
device = "cuda:2"
model = UNet().to(device)
ckpt = torch.load("../../src/models/DDPM_Unconditional/ckpt.pt")
model.load_state_dict(ckpt)

# Create a Diffusion model instance and sample from it
diffusion = Diffusion(img_size=32, device=device)
x = diffusion.sample(model, 10)
```

### Conditional Diffusion Model for 3 Channel Data:

```python
import torch

# Set the number of samples and device
n = 10
device = "cuda:3"

# Create a Diffusion model instance and load the trained model checkpoint
diffusion = Diffusion(img_size=32, device=device)
diffusion.load("../../src/models/DDPM_conditional")

# Prepare labels (class 1 for every sample) and sample from the diffusion model
labels = torch.full((n,), 1).long().to(diffusion.device)
sampled_images = diffusion.sample(use_ema=False, labels=labels)
```

### Classifier-Free Diffusion Model for 1 Channel Data:

```python
import torch

# Set the number of samples, device, and create the Unet1D_cls_free model
n = 10
device = "cuda:3"
model = Unet1D_cls_free(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    num_classes=5,
    cond_drop_prob=0.5,
    channels=1
)

# Load the pre-trained model checkpoint
ckpt = torch.load("../../src/signal/checkpoint/DDPM1D_cls_free_MITBIH/checkpoint.pt")
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)

# Create the GaussianDiffusion1D_cls_free model and sample from it
diffusion = GaussianDiffusion1D_cls_free(
    model,
    seq_length=128,
    timesteps=1000
).to(device)

# Request class 0 for all n samples
y = torch.Tensor([0] * n).long().to(device)
x = diffusion.sample(classes=y, cond_scale=3.)
```

### Self-Conditional Diffusion Model for 1 Channel Data:

```python
import torch

# Set the device and create the Unet1D model with self-conditioning
device = "cuda:3"
model = Unet1D(
    dim=64,
    self_condition=True,
    dim_mults=(1, 2, 4, 8),
    channels=1
)

# Load the pre-trained model checkpoint
ckpt = torch.load("../../src/signal/checkpoint/DDPM1D_Selfconditional_maskedCond/checkpoint.pt")
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)

# Create the GaussianDiffusion1D model to sample from
# seq_length must be divisible by dim_mults
diffusion = GaussianDiffusion1D(
    model,
    seq_length=128,
    timesteps=1000,
    objective='pred_v'
).to(device)
```

Make sure to adjust the file paths and model names as needed.

## Contributing

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
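
## Note: checking `seq_length` before sampling

The 1-channel sampling snippets above note that `seq_length` has to be compatible with `dim_mults`. The sketch below shows one way to check this before building the model. It assumes that every U-Net stage except the last halves the sequence length, which is common in 1D U-Net implementations but is not confirmed by this repository. The helper names `downsample_factor` and `check_seq_length` are hypothetical and not part of the codebase.

```python
def downsample_factor(dim_mults):
    """Total length reduction, ASSUMING every stage except the last halves the sequence."""
    return 2 ** (len(dim_mults) - 1)


def check_seq_length(seq_length, dim_mults):
    """Return the bottleneck length, or raise ValueError if seq_length does not divide evenly."""
    factor = downsample_factor(dim_mults)
    if seq_length % factor != 0:
        raise ValueError(
            f"seq_length={seq_length} is not divisible by {factor} "
            f"(implied by dim_mults={dim_mults})"
        )
    return seq_length // factor


# Values taken from the sampling snippets above
bottleneck_length = check_seq_length(128, (1, 2, 4, 8))
```

Under that assumption, a check like this fails early with a readable message instead of a shape mismatch deep inside the network.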
## Citation

Feel free to cite our paper using the BibTeX entry below or the .cff file in this repository.

```bibtex
@misc{li2024biodiffusion,
  title={BioDiffusion: A Versatile Diffusion Model for Biomedical Signal Synthesis},
  author={Xiaomin Li and Mykhailo Sakevych and Gentry Atkinson and Vangelis Metsis},
  year={2024},
  eprint={2401.10282},
  archivePrefix={arXiv},
  primaryClass={eess.SP}
}
```