lin-tianyu / Stable-Diffusion-Seg

[MICCAI 2024] Codebase for "Stable Diffusion Segmentation for Biomedical Images with Single-step Reverse Process"
https://lin-tianyu.github.io/Stable-Diffusion-Seg/

Multi-Segmentation #16

Open chloielam opened 1 month ago

chloielam commented 1 month ago

May I ask why, in the multi-class segmentation dataset, we assign the class_id to random existing classes?

lin-tianyu commented 1 month ago

Hi @chloielam! Glad you asked.

SDSeg is currently a binary segmentation model. The released multi-class segmentation code is still being refined experimentally, so it was not published with the SDSeg paper.

As for your question: first, in the SDSeg setting the AutoEncoder expects a 3-channel image as input to produce the corresponding latent representation, and we don't want to make any changes to the AutoEncoder.

Second, given the fixed 3-channel input, we tried feeding in the multi-label map directly, with values 0, 1, 2, 3, etc. replicated across all 3 channels, but it didn't work well.
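For readers unfamiliar with this setup, here is a minimal sketch of the naive encoding described above, assuming an integer label map of shape (H, W) and a channel-first (3, H, W) layout. The function name `naive_multilabel_to_3ch` is hypothetical and does not come from the SDSeg repo, and any intensity scaling the real AutoEncoder expects is omitted.

```python
import numpy as np


def naive_multilabel_to_3ch(label_map: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn an (H, W) label map with values 0..K into a (3, H, W) array.

    Class indices are used directly as pixel intensities and copied to all
    3 channels, mirroring the approach that did not work well.
    """
    x = label_map.astype(np.float32)
    # Add a channel axis and replicate it 3 times -> (3, H, W)
    return np.repeat(x[None, :, :], 3, axis=0)


label_map = np.array([[0, 1], [2, 3]])
x = naive_multilabel_to_3ch(label_map)
print(x.shape)  # (3, 2, 2)
```

Note that this encoding treats arbitrary class indices as ordered intensities, so, for example, class 3 differs from class 1 only by being "brighter".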

Instead, the multi-class segmentation code you're using now takes the segmentation map of randomly selected existing classes as the AutoEncoder's input, together with a label embedding that tells the model which class is being processed. With this approach, the multi-class version of SDSeg achieves reasonable segmentation results (though not SOTA yet).
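The conditioning idea can be sketched as follows: the multi-label map is reduced to a binary mask for one chosen class, replicated to 3 channels so the AutoEncoder input format is unchanged, and the class id selects a label embedding. This is a minimal illustration, not the repo's code; `NUM_CLASSES`, `EMBED_DIM`, the random `label_embeddings` table, and `make_conditioned_input` are all hypothetical (in a real model the table would be a trainable embedding layer).

```python
import numpy as np

NUM_CLASSES = 13  # hypothetical number of foreground classes
EMBED_DIM = 8     # hypothetical embedding size

rng = np.random.default_rng(0)
# Hypothetical embedding table: row c is the label embedding for class c (row 0 = background)
label_embeddings = rng.standard_normal((NUM_CLASSES + 1, EMBED_DIM)).astype(np.float32)


def make_conditioned_input(label_map: np.ndarray, class_id: int):
    """Return a 3-channel binary mask for class_id and that class's label embedding."""
    binary = (label_map == class_id).astype(np.float32)  # (H, W), 1 where pixel is class_id
    x = np.repeat(binary[None], 3, axis=0)                # (3, H, W), same format as an RGB image
    cond = label_embeddings[class_id]                     # (EMBED_DIM,)
    return x, cond


x, cond = make_conditioned_input(np.array([[0, 2], [2, 1]]), class_id=2)
print(x.shape, cond.shape)  # (3, 2, 2) (8,)
```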

As for the "existing" part: we don't want the model to sample too many empty classes, which provide no useful training signal and may lead to a severe class-imbalance problem.
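Restricting sampling to classes that actually appear in the label map could look like the sketch below. It assumes label 0 is background and draws uniformly among the present foreground classes; `sample_existing_class` and its "return None when nothing is present" behavior are hypothetical choices, not taken from the repo.

```python
import numpy as np


def sample_existing_class(label_map: np.ndarray, rng: np.random.Generator, background: int = 0):
    """Pick a class id uniformly among the foreground classes present in label_map.

    Returns None if the slice contains only background (hypothetical handling:
    the caller could skip or resample such a slice).
    """
    present = np.unique(label_map)
    present = present[present != background]  # drop background
    if present.size == 0:
        return None
    return int(rng.choice(present))


rng = np.random.default_rng(0)
label_map = np.zeros((4, 4), dtype=np.int64)
label_map[1:3, 1:3] = 5
label_map[0, 0] = 2
# Only classes 2 and 5 can ever be drawn; absent classes never yield an empty target
draws = {sample_existing_class(label_map, rng) for _ in range(100)}
```

By contrast, drawing uniformly from all class ids would frequently produce all-zero target masks on slices where most organs are absent, which is the imbalance the reply describes.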

In short, at each training step we sample some existing classes and train the model to segment each one given its label embedding. At inference, we feed in all the label embeddings one by one to generate segmentation results for every class in an image.
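The inference side, one class-conditioned pass per label embedding, can be sketched like this. `predict_binary` stands in for the model and is a hypothetical interface returning an (H, W) foreground probability map; merging overlapping predictions by highest probability, with a 0.5 threshold for background, is an assumption of this sketch rather than something stated in the thread.

```python
from typing import Callable, Sequence

import numpy as np


def segment_all_classes(image: np.ndarray,
                        class_ids: Sequence[int],
                        predict_binary: Callable[[np.ndarray, int], np.ndarray],
                        threshold: float = 0.5) -> np.ndarray:
    """Run a class-conditioned binary predictor once per class and merge into one label map."""
    # One forward pass per class: (C, H, W) stack of foreground probabilities
    probs = np.stack([predict_binary(image, c) for c in class_ids])
    best = probs.argmax(axis=0)                   # index of most confident class per pixel
    labels = np.asarray(class_ids)[best]          # map indices back to class ids
    labels[probs.max(axis=0) < threshold] = 0     # background where no class is confident
    return labels


# Toy stand-in for the model: "predicts" a pixel as class c if the input equals c
def dummy_predictor(image: np.ndarray, c: int) -> np.ndarray:
    out = np.zeros_like(image, dtype=np.float32)
    out[image == c] = 0.9
    return out


seg = segment_all_classes(np.array([[0, 1], [2, 1]]), [1, 2], dummy_predictor)
print(seg.tolist())  # [[0, 1], [2, 1]]
```

The cost of this design is one forward pass per class, so inference time grows linearly with the number of classes.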


Anyway, this is only a suboptimal way of extending SDSeg to multi-class segmentation. I'm still not sure why it doesn't work better; my best guess is that the BTCV dataset is too small.

Feel free to try SDSeg's multi-class segmentation code! If you come up with a better approach to multi-class segmentation, please don't hesitate to contact me. Maybe we can build an SDSeg-V2 or something together, lol

Best, Tianyu