Closed by liamchalcroft 7 months ago
Hi,
Thanks for raising an issue. Are you calling the SlidingWindowInferer with the `network` argument set to the `.sample()` method of the DiffusionInferer? If possible, providing a minimal example of what you're running would help us investigate.
Hi Mark,
Yep, precisely - I am using it in a translation task where we apply the diffusion model patch-wise over the image space rather than on the latent of an autoencoder. I imagine it could be used in the same manner for e.g. segmentation. Here is a minimal example:
```python
import torch

from generative.inferers import DiffusionInferer
from monai.inferers import SlidingWindowInferer
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
from generative.networks.schedulers.ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    num_channels=(16, 32, 64),
    attention_levels=(False, False, False),
    num_head_channels=(0, 0, 0),
    norm_num_groups=16,
)
inferer = DiffusionInferer(scheduler)
window = SlidingWindowInferer(roi_size=(32, 32, 32), sw_batch_size=8, overlap=0.25, mode='gaussian')

input_src = torch.ones(1, 1, 96, 96, 96)
output = window(torch.randn_like(input_src[:, 0][:, None][None]),
                inferer.sample,
                conditioning=input_src,
                cond_mode='concat',
                verbose=False)[0]
```
Thanks for the example. @ericspod any smart ideas on how we could get around this error? If not, I think we should rename the argument.
BTW, I had to slightly modify the example to get it working on the newer MONAI Generative; it's:
```python
import torch

from generative.inferers import DiffusionInferer
from monai.inferers import SlidingWindowInferer
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
from generative.networks.schedulers.ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    num_channels=(16, 32, 64),
    num_res_blocks=1,
    attention_levels=(False, False, False),
    norm_num_groups=16,
)
inferer = DiffusionInferer(scheduler)
window = SlidingWindowInferer(roi_size=(32, 32, 32), sw_batch_size=8, overlap=0.25, mode='gaussian')

input_src = torch.ones(1, 1, 96, 96, 96)
# The mode='concat' below is intended for inferer.sample, but it clashes
# with SlidingWindowInferer's own `mode` argument for patch merging:
output = window(torch.randn_like(input_src[:, 0][:, None]),
                inferer.sample,
                conditioning=input_src,
                mode='concat',
                verbose=False)
```
You can wrap the `inferer.sample` method with `functools.partial` to hard-code the mode as "concat", then pass that to the `window` call with the `mode` argument removed from the call:
```python
from functools import partial

output = window(torch.randn_like(input_src[:, 0][:, None]),
                partial(inferer.sample, mode='concat'),
                conditioning=input_src,
                verbose=False)
```
Something like that should work.
Thanks, hadn't thought of that! Will leave the issue in case you do want to avoid the conflict in this fairly rare use case, but otherwise happy to close on my end.
Closing as I suspect this use case will be fairly rare and we have a workaround, but we will revisit if more people flag this as an issue.
Have found an issue when using MONAI's SlidingWindowInferer in combination with the DiffusionInferer: passing a value for the conditioning `mode` argument throws an error, because SlidingWindowInferer has its own `mode` argument for patch merging.
I think it could be resolved by renaming the `mode` argument here to e.g. `cond_mode`.
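For anyone hitting the same thing, the collision and the `partial` workaround can be sketched with plain stand-in functions (hypothetical signatures mimicking the two inferers, not the real MONAI APIs; in real MONAI the window call raises rather than silently swallowing the kwarg, since `'concat'` is not a valid blend mode):

```python
from functools import partial


def sample(noise, conditioning=None, mode='crossattn', verbose=True):
    """Stand-in for DiffusionInferer.sample: `mode` selects the conditioning scheme."""
    return f"sample(mode={mode!r})"


def sliding_window(inputs, network, *args, mode='constant', **kwargs):
    """Stand-in for the sliding-window call: it consumes `mode` for patch
    merging, so a `mode` kwarg intended for `network` never reaches it."""
    return network(inputs, *args, **kwargs)


# mode='concat' is captured by the window, not forwarded to the sampler:
print(sliding_window('noise', sample, mode='concat', verbose=False))
# sample(mode='crossattn')

# Pre-binding mode with partial routes it to the sampler instead:
print(sliding_window('noise', partial(sample, mode='concat'), verbose=False))
# sample(mode='concat')
```

The same pre-binding trick works for any kwarg name that both the wrapper and the wrapped callable claim, which is why renaming to something like `cond_mode` is only needed for convenience.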