Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Conditioning mode variable name #435

Closed liamchalcroft closed 7 months ago

liamchalcroft commented 7 months ago

I've found an issue when using MONAI's SlidingWindowInferer in combination with the DiffusionInferer: passing a value for the conditioning 'mode' argument throws an error, because SlidingWindowInferer has its own 'mode' argument for patch merging.
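A toy sketch of the clash (not the actual MONAI code, just the keyword-forwarding pattern I believe is at play): the window forwards extra keyword arguments to the wrapped callable while also passing its own patch-merging mode internally, so a user-supplied 'mode' arrives twice.

```python
# Toy illustration only, not the real MONAI implementation: the window
# forwards **kwargs to an internal routine that already receives the
# window's own patch-merging mode, so a user-supplied 'mode' collides.
def _infer(inputs, network, mode, **kwargs):
    # 'mode' is the patch-merging mode; remaining kwargs go to the network
    return network(inputs, **kwargs)

class Window:
    def __init__(self, mode="constant"):
        self.mode = mode

    def __call__(self, inputs, network, **kwargs):
        return _infer(inputs, network, mode=self.mode, **kwargs)

window = Window(mode="gaussian")
window([1.0], lambda x, **kw: x, mode="concat")
# TypeError: _infer() got multiple values for keyword argument 'mode'
```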

I think it could be resolved by renaming the 'mode' argument here to e.g. 'cond_mode'.
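For illustration, a minimal sketch of what the rename could look like ('cond_mode' is the name proposed above, not an existing MONAI Generative argument, and the full sample signature is abbreviated):

```python
# Hypothetical sketch of the proposed rename; 'cond_mode' is the suggested
# replacement for 'mode' and does not exist in the current API.
def sample(input_noise, conditioning=None, cond_mode="crossattn", verbose=True):
    # 'cond_mode' would select how conditioning is applied ('crossattn' or
    # 'concat') without shadowing SlidingWindowInferer's 'mode'.
    ...
```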

marksgraham commented 7 months ago

Hi,

Thanks for raising an issue. So are you calling the SlidingWindowInferer with the network argument set to the .sample() method of the DiffusionInferer? If possible, providing a minimal example of what you're running would help us investigate.

liamchalcroft commented 7 months ago

Hi Mark,

Yep, precisely. I am using it in a translation task where we apply the diffusion model patch-wise over the image space rather than on the latent of an autoencoder. I imagine it could be used in the same manner for e.g. segmentation. Here is a minimal example:

```python
import torch
from generative.inferers import DiffusionInferer
from monai.inferers import SlidingWindowInferer
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
from generative.networks.schedulers.ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    num_channels=(16, 32, 64),
    attention_levels=(False, False, False),
    num_head_channels=(0, 0, 0),
    norm_num_groups=16,
)
inferer = DiffusionInferer(scheduler)
window = SlidingWindowInferer(roi_size=(32, 32, 32), sw_batch_size=8, overlap=0.25, mode='gaussian')

input_src = torch.ones(1, 1, 96, 96, 96)
output = window(torch.randn_like(input_src[:, 0][:, None]),
                inferer.sample,
                conditioning=input_src,
                cond_mode='concat',
                verbose=False)[0]
```

marksgraham commented 7 months ago

Thanks for the example. @ericspod any smart ideas on how we could get around this error? If not, I think we should rename the argument.

BTW I had to slightly modify the example to get it working on a newer MONAI Generative version; it's:

```python
import torch
from generative.inferers import DiffusionInferer
from monai.inferers import SlidingWindowInferer
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
from generative.networks.schedulers.ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    num_channels=(16, 32, 64),
    num_res_blocks=1,
    attention_levels=(False, False, False),
    norm_num_groups=16,
)
inferer = DiffusionInferer(scheduler)
window = SlidingWindowInferer(roi_size=(32, 32, 32), sw_batch_size=8, overlap=0.25, mode='gaussian')

input_src = torch.ones(1, 1, 96, 96, 96)

# 'mode' is intended for inferer.sample's conditioning, but it collides
# with the 'mode' keyword used internally for patch merging:
output = window(torch.randn_like(input_src[:, 0][:, None]),
                inferer.sample,
                conditioning=input_src,
                mode='concat',
                verbose=False)
```

ericspod commented 7 months ago

You can wrap the inferer.sample object with partial to hard-code the mode as "concat", then pass that to the window call with the mode argument removed from the call.

```python
from functools import partial

output = window(torch.randn_like(input_src[:, 0][:, None]),
                partial(inferer.sample, mode='concat'),
                conditioning=input_src,
                verbose=False)
```

Something like that should work.
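
If you'd rather avoid the functools import, an inline lambda is an equivalent way to bind the conditioning mode before the window forwards its kwargs:

```python
# Equivalent alternative to partial: bind mode inline with a lambda.
output = window(torch.randn_like(input_src[:, 0][:, None]),
                lambda x, **kwargs: inferer.sample(x, mode='concat', **kwargs),
                conditioning=input_src,
                verbose=False)
```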

liamchalcroft commented 7 months ago

> You can wrap the inferer.sample object with partial to hard-code the mode as "concat", then pass that to the window call with the mode argument removed from the call.
>
> ```python
> from functools import partial
>
> output = window(torch.randn_like(input_src[:, 0][:, None]),
>                 partial(inferer.sample, mode='concat'),
>                 conditioning=input_src,
>                 verbose=False)
> ```
>
> Something like that should work.

Thanks, I hadn't thought of that! I'll leave the issue open in case you do want to avoid the conflict in this fairly rare use case, but otherwise I'm happy to close on my end.

marksgraham commented 7 months ago

Closing as I suspect this use case will be fairly rare and we have a solution, but we'll revisit if more people flag this as an issue.