Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Step ratio in scheduler code #466

Closed Ahmad-Omar-Ahsan closed 4 months ago

Ahmad-Omar-Ahsan commented 4 months ago

https://github.com/Project-MONAI/GenerativeModels/blob/a473b5f7db4be41129a4ee2f839d4a09cfdc1595/generative/networks/schedulers/ddpm.py#L119

Hello,

I'd like to know why the code has this step.

Let's say I train a diffusion model for 1000 timesteps (train_timesteps=1000). During training, I want to add 20% noise and then use the Inferer and Scheduler to remove that 20% noise. In your DDPM tutorial, you set the number of inference timesteps in the scheduler and then call inferer.sample to generate the images:

scheduler.set_timesteps(num_inference_steps=1000)
with autocast(enabled=True):
    image, intermediates = inferer.sample(
        input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100
    )

But because of the step ratio, if I set the number of inference steps to 200, instead of getting an array of t from 200 down to 0, I get an array that starts at 995 and goes down to 0 in jumps of 5, which is not the array of timesteps I want.

Could you explain why the step ratio was added?

The second thing I wanted to ask: if I want to create an array of timesteps from 200 down to 0 and set it as scheduler.timesteps, what would be the best way to do it? Should I create a function and assign its output to scheduler.timesteps, or directly modify the scheduler code by adding the function?

marksgraham commented 4 months ago

The step ratio is for sampling in an accelerated fashion, i.e. moving from t=T to t=0 while only visiting every n-th step, where T here is 1000.
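
Concretely, here is roughly what set_timesteps does with the step ratio (a simplified sketch, not the exact library code):

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 200

# Integer step ratio between the training and inference schedules
step_ratio = num_train_timesteps // num_inference_steps  # 5
# Evenly spaced, descending timesteps: [995, 990, ..., 5, 0]
timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1]

So with num_train_timesteps=1000 and num_inference_steps=200 you get [995, 990, ..., 5, 0], which matches what you're seeing.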

It sounds like you want to do something more akin to reconstruction: partially noise the image, then denoise it. You'll have to do that outside the inferer; it isn't built for that, but it is straightforward to do. I've got some code in another project where I do it:

https://github.com/marksgraham/ddpm-ood/blob/e5b9ead8405c0b9f1f9cb906e5ba0f83a0698158/src/trainers/reconstruct.py#L143-L157

Here t_start is the timestep you want to noise up to (for you, t=200), and then you manually iterate through the denoising loop, selecting only the timesteps <= 200.

You will see there is no need to overwrite scheduler.timesteps.
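
Something like this minimal sketch (using the scheduler's add_noise / step methods; the model call signature is assumed here, adapt it to your setup):

import torch

t_start = 200  # timestep to noise up to

with torch.no_grad():
    # Partially noise the images up to t_start
    noise = torch.randn_like(images)
    t = torch.full((images.shape[0],), t_start, dtype=torch.long, device=images.device)
    x = scheduler.add_noise(original_samples=images, noise=noise, timesteps=t)

    # Denoise manually, visiting only timesteps <= t_start
    scheduler.set_timesteps(num_inference_steps=scheduler.num_train_timesteps)
    for timestep in scheduler.timesteps.tolist():
        if timestep > t_start:
            continue
        model_output = model(x, timesteps=torch.tensor([timestep], device=x.device))
        x, _ = scheduler.step(model_output, timestep, x)
    # x now holds the partially-noised-then-denoised images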

Ahmad-Omar-Ahsan commented 4 months ago

Thank you. In my implementation, I created a function that returns the timesteps in descending order ([200, ..., 0]) without using the step ratio. I then assign its output to scheduler.timesteps and pass the scheduler to inferer.sample. I am using PyTorch Lightning to train the MONAI DDPM code, so the on_train_epoch_end hook implements this. It looks like this:

import torch
from torchvision.utils import make_grid


def on_train_epoch_end(self):
    images = self.batches[0]
    current_epoch = self.current_epoch + 1
    if current_epoch % 50 == 0:
        print(f"On training epoch: {self.current_epoch} end\n")

        # Noise the images up to 20% of the training timesteps
        int_timesteps = int(0.2 * self.inferer.scheduler.num_train_timesteps)
        timesteps = torch.tensor(int_timesteps, dtype=torch.long)

        noise = torch.randn_like(images)
        noisy_image = self.scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

        # Custom helper that returns [int_timesteps, ..., 0] without applying the step ratio
        self.scheduler.timesteps = set_timesteps_without_ratio(num_inference_steps=int_timesteps, device=images.device)
        images_denoised = self.inferer.sample(input_noise=noisy_image, diffusion_model=self.model, scheduler=self.scheduler)

        grid_noisy = make_grid(noisy_image, nrow=4)
        grid_train_images = make_grid(images, nrow=4)
        grid_denoised_images = make_grid(images_denoised, nrow=4)

        self.logger.experiment.add_image("Training images", grid_train_images, current_epoch)
        self.logger.experiment.add_image("Noised train images", grid_noisy, current_epoch)
        self.logger.experiment.add_image("Denoised train images", grid_denoised_images, current_epoch)

    self.batches.clear()
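
A minimal version of set_timesteps_without_ratio (assuming it only needs to return the timesteps in descending order with no step ratio applied, which is what I use it for) could look like:

def set_timesteps_without_ratio(num_inference_steps, device=None):
    # Descending timesteps [num_inference_steps, ..., 1, 0], no step ratio applied
    return torch.arange(num_inference_steps, -1, -1, dtype=torch.long, device=device)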