JuliaWolleb / Diffusion-based-Segmentation

This is the official PyTorch implementation of the paper "Diffusion Models for Implicit Image Segmentation Ensembles".
MIT License

Train with multiple classes #44

Closed. suhas-srinath closed this issue 1 year ago.

suhas-srinath commented 1 year ago

Hi, thanks a lot for the code! I am interested in training with multiple segmentation classes, but some parts of the code convert the masks into binary masks. How does training work with multiple segmentation classes, and which parameter arguments should I set for it?

JuliaWolleb commented 1 year ago

Hi, if you have multiple classes, you will need a one-hot encoding of the classes. For example, with 3 classes the model has 3 output channels, each holding the binary mask of the respective class. Set the number of output channels (out_channels) in the file guided_diffusion/script_util.py to 2*3=6; the factor of 2 is there because with learn_sigma=True the model predicts a variance channel alongside each predicted channel. If you set the flag learn_sigma=False, set out_channels=3.
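A minimal sketch of the two steps described above, assuming the ground-truth mask is stored as an integer label map of shape (H, W) with values 0..C-1. The helper names `one_hot_mask`, `segmentation_out_channels` and `split_prediction` are hypothetical and not part of the repository; NumPy is used here only so the sketch runs without a GPU, and the same logic carries over to PyTorch tensors. The split into two halves along the channel axis is an assumption modelled on how the upstream guided-diffusion code separates the learned-variance output.

```python
import numpy as np


def one_hot_mask(label_map: np.ndarray, num_classes: int) -> np.ndarray:
    """Turn an (H, W) integer label map into a (num_classes, H, W) stack of binary masks.

    Assumes every pixel holds a class index in [0, num_classes).
    """
    if label_map.min() < 0 or label_map.max() >= num_classes:
        raise ValueError("label values must lie in [0, num_classes)")
    # Broadcasting (1, H, W) against (C, 1, 1) gives one binary channel per class.
    classes = np.arange(num_classes).reshape(num_classes, 1, 1)
    return (label_map[None, :, :] == classes).astype(np.float32)


def segmentation_out_channels(num_classes: int, learn_sigma: bool) -> int:
    """Hypothetical helper: the value to use for out_channels."""
    # With learn_sigma=True each predicted channel gets an extra variance channel.
    return 2 * num_classes if learn_sigma else num_classes


def split_prediction(model_output: np.ndarray, num_classes: int):
    """Assumed layout: split a (B, 2*C, H, W) output into prediction and variance halves."""
    return model_output[:, :num_classes], model_output[:, num_classes:]


if __name__ == "__main__":
    mask = np.array([[0, 1], [2, 1]])
    onehot = one_hot_mask(mask, num_classes=3)
    print(onehot.shape)                                      # (3, 2, 2)
    print(segmentation_out_channels(3, learn_sigma=True))   # 6
    print(segmentation_out_channels(3, learn_sigma=False))  # 3
```

Because every pixel belongs to exactly one class, the one-hot channels sum to 1 at each pixel, which is a quick sanity check when preparing the training masks.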

Let me know if you have further questions.

suhas-srinath commented 1 year ago

Hi, thanks a lot for the response! I'm able to train it for 3 classes now.