Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0
555 stars, 78 forks

Class conditioned diffusion #486

Open talg2324 opened 2 months ago

talg2324 commented 2 months ago

Hi! Thanks for the great work here.

When working with class-conditioned diffusion, the available inferers don't accept a class_label input; they only offer a context, passed in either concat or crossattn mode. However, the diffusion model has all the infrastructure needed and seems to handle class conditioning well.
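To make the two conditioning paths concrete: a class label usually enters the network as one learned vector added to the timestep embedding, whereas a context is a sequence that the attention layers attend to (crossattn) or extra input channels (concat). The library-free sketch below shows only the class-label path. It is a toy under stated assumptions: timestep_embedding, class_table and conditioning_embedding are hypothetical names, the random table stands in for a learned embedding, and in MONAI's DiffusionModelUNet the sinusoidal vector, to my understanding, first passes through a small MLP, which this sketch skips.

```python
import math
import random

# Toy sketch (assumption): mimics the common UNet design in which a learned
# class embedding is added to the timestep embedding. All names are hypothetical.
EMBED_DIM = 8
NUM_CLASSES = 4


def timestep_embedding(t, dim=EMBED_DIM, max_period=10000.0):
    """Sinusoidal embedding of a timestep: cosines first, then sines."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


# Stand-in for a learned embedding table (NUM_CLASSES rows): random values here.
_rng = random.Random(0)
class_table = [[_rng.gauss(0.0, 1.0) for _ in range(EMBED_DIM)] for _ in range(NUM_CLASSES)]


def conditioning_embedding(t, class_label=None):
    """Vector fed to every residual block (a real model would also apply an MLP)."""
    emb = timestep_embedding(t)
    if class_label is not None:
        # Class conditioning: elementwise sum with the label's row of the table.
        emb = [e + c for e, c in zip(emb, class_table[class_label])]
    return emb


# Same timestep, different labels -> different embeddings.
print(conditioning_embedding(500, class_label=0)[:2])
print(conditioning_embedding(500, class_label=3)[:2])
```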

I am wondering whether this is intentional and I am misusing the class label, or whether there is a plan to add class_label inputs to the inferer classes. If I understood correctly, it would be a very small code change.
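The "small code change" amounts to threading one more keyword argument from the inferer down to the model call. The sketch below is hypothetical and library-free: inferer_call, add_noise and toy_model are illustrative names rather than MONAI's API, plain lists stand in for tensors, and concatenation of lists stands in for channel-wise concatenation.

```python
import math


def add_noise(x0, noise, alpha_bar_t):
    """Forward diffusion for one timestep: sqrt(ab) * x0 + sqrt(1 - ab) * noise."""
    a = math.sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + s * n for x, n in zip(x0, noise)]


def inferer_call(inputs, model, noise, timestep, alpha_bar_t,
                 condition=None, mode="crossattn", class_labels=None):
    """Hypothetical inferer call that also forwards class labels to the model."""
    if mode not in ("crossattn", "concat"):
        raise ValueError(f"unsupported mode: {mode}")
    noisy = add_noise(inputs, noise, alpha_bar_t)
    if mode == "concat":
        if condition is None:
            raise ValueError("concat mode needs a condition")
        # Condition joins the input "channels"; no cross-attention context.
        return model(noisy + list(condition), timestep, context=None, class_labels=class_labels)
    # crossattn: the condition is passed as context; class labels travel separately.
    return model(noisy, timestep, context=condition, class_labels=class_labels)


calls = []


def toy_model(x, t, context=None, class_labels=None):
    """Hypothetical model that records what it received and predicts zero noise."""
    calls.append({"n": len(x), "t": t, "context": context, "class_labels": class_labels})
    return [0.0] * len(x)


print(inferer_call([1.0, 2.0], toy_model, [0.1, -0.1], 10, 0.25, class_labels=[3]))
print(calls[-1]["class_labels"])  # -> [3]
```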

I did find this: https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/classifier_free_guidance/2d_ddpm_classifier_free_guidance_tutorial.ipynb

But I don't understand why I should use the context input instead of the class_label, which seems more relevant. Did I misunderstand something?
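The linked tutorial is about classifier-free guidance, which works the same way whichever path carries the label: during training the label is sometimes replaced by a "null" token so the model also learns an unconditional prediction, and at sampling time the two predictions are combined. The sketch below does not reproduce the tutorial's code; NULL_LABEL = -1, drop_labels, guided_noise and toy_eps_model are hypothetical. With a learned embedding table the null token would need its own row (for example index NUM_CLASSES) rather than -1.

```python
import random

NULL_LABEL = -1  # hypothetical "no class" token used for the unconditional branch


def drop_labels(labels, p_uncond, rng):
    """Training: replace each label by NULL_LABEL with probability p_uncond."""
    return [NULL_LABEL if rng.random() < p_uncond else y for y in labels]


def guided_noise(eps_model, x, t, label, guidance_scale):
    """Sampling: eps_uncond + w * (eps_cond - eps_uncond)."""
    eps_cond = eps_model(x, t, label)
    eps_uncond = eps_model(x, t, NULL_LABEL)
    return [u + guidance_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


def toy_eps_model(x, t, label):
    """Hypothetical predictor whose output simply echoes the label."""
    return [float(label)] * len(x)


print(drop_labels([0, 1, 2, 3], p_uncond=0.5, rng=random.Random(0)))
print(guided_noise(toy_eps_model, [0.0, 0.0], t=10, label=2, guidance_scale=3.0))
# -> [8.0, 8.0]  since -1 + 3 * (2 - (-1)) = 8
```

A guidance scale of 0 gives the unconditional prediction, 1 gives the plain conditional one, and values above 1 push samples further toward the requested class.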

ericspod commented 2 months ago

@marksgraham or @virginiafdez may have further insight here, but it's possibly an oversight in the DiffusionInferer class that we can rectify during the integration into MONAI core.

marksgraham commented 2 months ago

Yeah, I think we should support passing a class_label to the inferers. I suspect the reason we don't have it is that we found conditioning tends to work best through context, but we should support it and let users decide for themselves. I'll try to add it in the refactor.

@talg2324, we're moving MONAI Generative into MONAI core, so you'll have to wait for that port to be complete; the new feature will be available there.

Ahmad-Omar-Ahsan commented 1 month ago

I have a question regarding the scheduler used for class-conditional sampling.

I noticed that the tutorial uses DDPM. Can we use DDIM instead to generate the samples? If I use the same noise vector and the labels [0, 1, 2, 3] for 4 disease conditions, the DDPM scheduler generates 4 distinct images, yet with DDIM it generates the same image. Can't we use DDIM to emulate images of the same subject under different disease conditions?
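One way to read this observation: ancestral DDPM sampling injects fresh Gaussian noise at every step, so samples that start from the same noise vector still diverge even if the model ignores the label, while DDIM with eta = 0 is deterministic given the starting noise. Identical DDIM outputs across labels therefore suggest the label has little effect on the noise prediction. The scalar toy below illustrates this under stated assumptions: a linear beta schedule, the variance choice sigma_t^2 = beta_t for DDPM, and a hypothetical predictor toy_eps whose label_strength controls how much the label matters. It is not MONAI's scheduler code.

```python
import math
import random

# Toy scalar setup (assumption): linear beta schedule, hypothetical noise predictor.
T = 50
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = []
_prod = 1.0
for a in alphas:
    _prod *= a
    alpha_bars.append(_prod)


def toy_eps(x, label, label_strength):
    """Hypothetical predictor; label_strength=0.0 means the label is ignored."""
    return 0.5 * x + label_strength * label


def ddim_sample(x_T, label, label_strength):
    """DDIM with eta = 0: no noise is injected, so x_T fully determines the output."""
    x = x_T
    for t in reversed(range(T)):
        eps = toy_eps(x, label, label_strength)
        ab_t = alpha_bars[t]
        ab_prev = alpha_bars[t - 1] if t > 0 else 1.0
        x0_pred = (x - math.sqrt(1.0 - ab_t) * eps) / math.sqrt(ab_t)
        x = math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps
    return x


def ddpm_sample(x_T, label, label_strength, rng):
    """Ancestral DDPM: fresh Gaussian noise is added at every step except the last."""
    x = x_T
    for t in reversed(range(T)):
        eps = toy_eps(x, label, label_strength)
        mean = (x - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(alphas[t])
        z = rng.gauss(0.0, 1.0) if t > 0 else 0.0
        x = mean + math.sqrt(betas[t]) * z
    return x


x_T = 1.3  # the same starting noise for every label
labels = [0, 1, 2, 3]
# Label ignored: DDIM collapses to one output; DDPM differs only via per-sample noise.
print([ddim_sample(x_T, y, 0.0) for y in labels])
print([ddpm_sample(x_T, y, 0.0, random.Random(seed)) for seed, y in enumerate(labels)])
# Label used: DDIM outputs now differ by label.
print([ddim_sample(x_T, y, 0.2) for y in labels])
```

Reusing one seed for every label (for example random.Random(7) each time) makes the DDPM outputs coincide as well when the label is ignored, which gives a fairer check of whether the label actually changes the result.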

marksgraham commented 1 month ago

You can swap the DDPM scheduler for a DDIM scheduler during sampling. In my experience, a model that produces good samples with DDPM sampling sometimes won't produce good samples with DDIM sampling; this can indicate that the model needs more training.
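Swapping schedulers needs no retraining because DDIM reuses the same noise predictor and the noise schedule the model was trained with; what changes is the update rule and, usually, the number of steps. As a sketch, assuming the common "leading" spacing (ddim_timesteps is a hypothetical helper; in MONAI Generative this job belongs to the scheduler's set_timesteps method, whose exact rule may differ), the sampler visits an evenly spaced, descending subset of the training timesteps:

```python
def ddim_timesteps(num_train_timesteps, num_inference_steps, steps_offset=0):
    """Evenly spaced, descending subset of the training timesteps ("leading" spacing).

    Assumption: one common DDIM convention, not necessarily MONAI's exact rule.
    """
    if not 0 < num_inference_steps <= num_train_timesteps:
        raise ValueError("need 0 < num_inference_steps <= num_train_timesteps")
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


timesteps = ddim_timesteps(1000, 50)
print(timesteps[:3], timesteps[-1])  # -> [980, 960, 940] 0
# Each DDIM step then jumps from t to t - 20, reading alpha_bar at both points
# from the same schedule the model was trained with.
```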