Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Why doesn't the loss of my diffusion model improve? #394

Closed: AbrahamGhavabesh closed this issue 1 year ago

AbrahamGhavabesh commented 1 year ago

@marksgraham Hi, I am working on super-resolution with an AutoencoderKL and a diffusion model, aiming to go from 128×128 to 256×256 on the brain tumor MONAI dataset, but the loss of my diffusion model is stuck at around 0.26-0.28 without improving. Increasing the number of timesteps from 1000 to 2000 decreases the loss to 0.2, but it ruins the validation images and shows no improvement, so time_steps=1000 seems fine. Under these conditions, what should I do to decrease the loss? Here is my code.

My model:

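```python
import torch
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

unet = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=4,  # 3 latent channels + 1 low-resolution image channel
    out_channels=3,
    num_res_blocks=2,
    num_channels=(256, 512, 1024),
    attention_levels=(False, True, True),
    num_head_channels=(0, 256, 256),
    # Without class embeddings, `class_labels=low_res_timesteps` is ignored in forward(),
    # so the UNet never sees the low-res noise level. Must be >= max_noise_level.
    num_class_embeds=1000,
)
unet = unet.to(device)
# Schedule for the latents, and a separate (identically configured) one for the low-res noise augmentation.
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0001, beta_end=0.0195)
low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0001, beta_end=0.0195)
max_noise_level = 350  # upper bound (exclusive) on the noise level applied to the low-res conditioning
```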

Training the diffusion model:

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

optimizer = torch.optim.Adam(unet.parameters(), lr=5e-5)

scaler_diffusion = GradScaler()

n_epochs = 1  # 200
val_interval = 1  # 20
epoch_loss_list = []
val_epoch_loss_list = []

for epoch in range(n_epochs):
    unet.train()
    autoencoderkl.eval()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        low_res_image = batch["low_res_image"].to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            with torch.no_grad():
                latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor

            # Noise augmentation: the target latent gets t in [0, 1000), the low-res conditioning in [0, max_noise_level)
            noise = torch.randn_like(latent)
            low_res_noise = torch.randn_like(low_res_image)
            timesteps = torch.randint(0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device).long()
            low_res_timesteps = torch.randint(
                0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device
            ).long()

            noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
            noisy_low_res_image = low_res_scheduler.add_noise(
                original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
            )

            # Condition on the low-res image by channel concatenation (3 + 1 = 4 input channels)
            latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)

            noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)
            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler_diffusion.scale(loss).backward()
        scaler_diffusion.step(optimizer)
        scaler_diffusion.update()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        unet.eval()
        val_loss = 0
        for val_step, batch in enumerate(val_loader, start=1):
            images = batch["image"].to(device)
            low_res_image = batch["low_res_image"].to(device)

            with torch.no_grad():
                with autocast(enabled=True):
                    latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor
                    # Noise augmentation
                    noise = torch.randn_like(latent)
                    low_res_noise = torch.randn_like(low_res_image)
                    timesteps = torch.randint(
                        0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device
                    ).long()
                    low_res_timesteps = torch.randint(
                        0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device
                    ).long()

                    noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
                    noisy_low_res_image = low_res_scheduler.add_noise(
                        original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
                    )

                    latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)
                    noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)
                    loss = F.mse_loss(noise_pred.float(), noise.float())

            val_loss += loss.item()
        val_loss /= val_step
        val_epoch_loss_list.append(val_loss)
        print(f"Epoch {epoch} val loss: {val_loss:.4f}")

        # Sampling image during training
        sampling_image = low_res_image[0].unsqueeze(0)
        latents = torch.randn((1, 3, 128, 128), device=device)
        low_res_noise = torch.randn((1, 1, 128, 128), device=device)
        noise_level = torch.tensor([20], dtype=torch.long, device=device)
        noisy_low_res_image = low_res_scheduler.add_noise(
            original_samples=sampling_image,
            noise=low_res_noise,
            timesteps=noise_level,
        )

        scheduler.set_timesteps(num_inference_steps=1000)
        for t in tqdm(scheduler.timesteps, ncols=110):
            with torch.no_grad():
                with autocast(enabled=True):
                    latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
                    noise_pred = unet(
                        x=latent_model_input, timesteps=t.reshape(1).to(device), class_labels=noise_level
                    )
                latents, _ = scheduler.step(noise_pred, t, latents)

        with torch.no_grad():
            decoded = autoencoderkl.decode_stage_2_outputs(latents / scale_factor)

        low_res_bicubic = nn.functional.interpolate(sampling_image, (256, 256), mode="bicubic")
```
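The code above multiplies latents by `scale_factor`, which the snippet does not define. As far as I know, the MONAI latent diffusion tutorials compute it once as the inverse standard deviation of the latents of one training batch, so the diffusion model sees roughly unit-variance inputs. A minimal sketch under that assumption, with a random tensor standing in for `autoencoderkl.encode_stage_2_inputs(images)`:

```python
import torch


def compute_scale_factor(latent: torch.Tensor) -> torch.Tensor:
    """Return 1 / std of a batch of latents, so latent * scale_factor has unit std."""
    return 1.0 / torch.std(latent)


# Hypothetical stand-in for autoencoderkl.encode_stage_2_inputs(images):
# a batch of 3-channel 128x128 latents whose std is far from 1.
torch.manual_seed(0)
latent = 4.0 * torch.randn(8, 3, 128, 128) + 0.5

scale_factor = compute_scale_factor(latent)
scaled = latent * scale_factor
print(f"scale_factor={scale_factor.item():.4f}, std after scaling={scaled.std().item():.4f}")
```

The same value has to be reused at sampling time (`latents / scale_factor` before decoding); otherwise the decoder receives latents on the wrong scale.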

My result after 15 epochs is below: [image: Capture]

Thanks for any help!

marksgraham commented 1 year ago

Hi,

AbrahamGhavabesh commented 1 year ago
1. > How many epochs have you tried training for? Is there still no change after 100?

   Every epoch with batch_size=1 takes 30 minutes to finish, so far I have only reached epoch 18.

2. > How do your autoencoder reconstructions look - are they OK?

   They work. Following your earlier advice, I trained the AutoencoderKL to a loss of 0.008, and the reconstructed images look fine.

```python
from generative.networks.nets import AutoencoderKL, PatchDiscriminator

# Define AutoencoderKL
autoencoderkl = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 128),
    latent_channels=3,
    num_res_blocks=2,
    norm_num_groups=32,
    attention_levels=(False, False),
)
autoencoderkl = autoencoderkl.to(device)

discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, num_channels=64)
discriminator = discriminator.to(device)
```
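The sampling code draws `latents` of shape `(1, 3, 128, 128)` and concatenates them with the 128×128 low-resolution image, which only works if the autoencoder maps a 256×256 image to a 3-channel 128×128 latent. Assuming the encoder halves the spatial size once between consecutive entries of `num_channels` (so once here, with `num_channels=(128, 128)`), the shapes can be checked with a small hypothetical helper:

```python
def latent_shape(image_hw, num_channels, latent_channels):
    """Hypothetical helper: spatial size halves once per transition between levels (assumption)."""
    factor = 2 ** (len(num_channels) - 1)
    h, w = image_hw
    if h % factor or w % factor:
        raise ValueError(f"image size {image_hw} is not divisible by {factor}")
    return (latent_channels, h // factor, w // factor)


# AutoencoderKL settings from this thread: num_channels=(128, 128), latent_channels=3.
z_shape = latent_shape((256, 256), num_channels=(128, 128), latent_channels=3)
low_res_shape = (1, 128, 128)  # single-channel 128x128 conditioning image

# The UNet input is the latent concatenated with the low-res image along channels,
# so spatial sizes must match and in_channels must equal latent_channels + 1.
assert z_shape[1:] == low_res_shape[1:]
unet_in_channels = z_shape[0] + low_res_shape[0]
print(z_shape, unet_in_channels)  # (3, 128, 128) 4
```

Under the same assumption, adding a third level to `num_channels` would give 64×64 latents, which would no longer line up with a 128×128 conditioning image.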


The AutoencoderKL is well trained, as its output shows:
![autoecoder6epoch](https://github.com/Project-MONAI/GenerativeModels/assets/118294105/812cdf6c-c2b1-4122-8b34-37ea396bf0a7)

3. > Does your code work if you try to upsample from 16^2 to 64^2, as in the tutorial? If so, we know it isn't a bug in your code and it's a matter of working out how to adapt to higher res. If it doesn't work then you need to work out why the tutorial works and your code doesn't.

   It works for lower resolutions such as 32 and 64, but since I now use this code for the higher resolutions 128 and 256, I think the UNet, num_train_timesteps, beta_start, beta_end, max_noise_level, etc. should be modified.

4. > I've noticed your beta_start is a lot lower than the tutorial too, try beta_start=0.0015 - though I'd be surprised if this was the difference.

   I chose to decrease beta_start from 0.0015 to 0.0001 because my low-resolution input is just the 256×256 ground-truth image resized to 128×128, without rotation or any other change, so I expected the model to train well with a lower beta, if I am right.
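beta_start has its strongest effect at the low noise levels, which are exactly the ones used for the low-resolution conditioning image (`noise_level = 20` at sampling, up to `max_noise_level = 350` in training). A minimal sketch, assuming the `linear_beta` schedule spaces betas linearly between `beta_start` and `beta_end` over `num_train_timesteps` steps, compares the standard deviation sqrt(1 - alpha_bar_t) of the noise added at those levels:

```python
import numpy as np


def alpha_bar(beta_start: float, beta_end: float, num_train_timesteps: int) -> np.ndarray:
    """Cumulative product of (1 - beta_t) for linearly spaced betas (assumed schedule)."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def noise_std(beta_start: float, t: int, beta_end: float = 0.0195, T: int = 1000) -> float:
    """Std of the noise term in x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
    return float(np.sqrt(1.0 - alpha_bar(beta_start, beta_end, T)[t]))


for t in (20, 350):
    for beta_start in (0.0001, 0.0015):
        print(f"t={t:3d} beta_start={beta_start}: noise std = {noise_std(beta_start, t):.3f}")
```

With the lower beta_start, the conditioning image is corrupted noticeably less at small noise levels, so the noise augmentation on the low-res input becomes weaker. Since beta_start is also shared by the latent schedule in this setup, whether that helps is something to check empirically on the samples.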
My output so far looks like this:

![Capture](https://github.com/Project-MONAI/GenerativeModels/assets/118294105/25328c6f-13cf-4913-b61b-424b39a7600f)

marksgraham commented 1 year ago

That last image you shared - is that the prediction from the DDPM? If so, it doesn't look too bad...

Also, in the tutorial the loss plateaus at about 0.1. I would expect your loss to plateau at a higher value because you're working on a different, and harder, problem. But given your samples look OK, I think things are fine and you just need to keep training. Even if the loss barely changes, it can be worth training the DDPM for longer; very small changes in the loss can lead to noticeable improvements during sampling.
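One way to see progress that the epoch-averaged loss hides is to log the MSE separately per timestep range: the per-timestep loss of a noise-prediction model varies a lot with t, so an improvement confined to one range can vanish in the average. A minimal sketch, assuming per-sample losses are computed with `reduction="none"` instead of the scalar `F.mse_loss` used in the training loop; the helper `loss_per_timestep_bucket` is hypothetical:

```python
import torch
import torch.nn.functional as F


def loss_per_timestep_bucket(noise_pred, noise, timesteps, num_train_timesteps=1000, num_buckets=4):
    """Mean MSE per equal-width bucket of timesteps; NaN for buckets absent from the batch."""
    # Per-sample MSE: average over every dimension except the batch dimension.
    per_sample = F.mse_loss(noise_pred.float(), noise.float(), reduction="none")
    per_sample = per_sample.flatten(start_dim=1).mean(dim=1)
    bucket = torch.div(timesteps * num_buckets, num_train_timesteps, rounding_mode="floor")
    bucket = bucket.clamp(max=num_buckets - 1)
    out = torch.full((num_buckets,), float("nan"), device=per_sample.device)
    for b in range(num_buckets):
        mask = bucket == b
        if mask.any():
            out[b] = per_sample[mask].mean()
    return out


torch.manual_seed(0)
noise = torch.randn(8, 3, 16, 16)
noise_pred = noise + 0.1 * torch.randn_like(noise)  # hypothetical stand-in for the UNet output
timesteps = torch.tensor([10, 100, 300, 400, 600, 700, 900, 990])
bucket_losses = loss_per_timestep_bucket(noise_pred, noise, timesteps)
print(bucket_losses)
```

Accumulating these per epoch (ignoring NaN entries) gives one curve per timestep range instead of a single flat average.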

AbrahamGhavabesh commented 1 year ago

Yes, the third image, named High res_images, is the DDPM prediction. It looks fine, but the precision needs to be improved. How is it possible that the output gets better while the loss is stuck at 0.26? In the tutorial the loss drops from 0.285 to 0.16 right after the first epoch and reaches 0.13 after 19 epochs, but my loss has not changed at all after 18 epochs.

marksgraham commented 1 year ago

I think even very small changes in the loss (ones you can't see at two-decimal precision) can have an effect, because that improvement is applied cumulatively over 1000 steps during sampling. I would monitor the samples and see if they look to be improving. I don't think you can directly compare your loss dynamics to the tutorial's, because working on larger image sizes is much harder.
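To monitor samples rather than the loss, one simple option is to compute an image metric between the decoded super-resolved sample and the 256×256 ground truth at each validation interval, alongside the bicubic baseline. A minimal PSNR sketch in plain PyTorch (MONAI's `monai.metrics` module also provides a PSNR metric; the helper below is a hypothetical self-contained stand-in), assuming intensities scaled to [0, 1]:

```python
import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Per-image PSNR in dB for (B, C, H, W) tensors with values in [0, max_val]."""
    mse = (pred.float() - target.float()).pow(2).flatten(start_dim=1).mean(dim=1)
    return 10.0 * torch.log10(max_val ** 2 / mse.clamp_min(1e-12))


torch.manual_seed(0)
target = torch.rand(2, 1, 256, 256)  # hypothetical stand-in for batch["image"]
sample = (target + 0.05 * torch.randn_like(target)).clamp(0, 1)  # stand-in for `decoded`
print(psnr(sample, target))
```

Tracking this for `decoded` and for `low_res_bicubic` over epochs shows whether the diffusion model is pulling ahead of plain interpolation, even while the training loss looks flat.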

AbrahamGhavabesh commented 1 year ago

The paper "Cascaded Diffusion Models for High Fidelity Image Generation" says the following: [image: Capture]. Based on that paper, I should use 2000 or 4000 timesteps for training the diffusion model and 100 timesteps for inference. But when I do that, the images become messy, even though the loss becomes smaller than 0.02. So I wonder which parameters are the right ones to keep training with. Thanks for your help!

marksgraham commented 1 year ago

Hi

I think if you train with 2000 timesteps you might need to train for longer before you start getting good samples. I think you are fixating on the loss too much. You're interested in sample quality; the loss does not directly measure that.
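One reason the loss can drop when moving to 2000 timesteps without the samples improving: with the same `beta_start`/`beta_end`, a linear schedule stretched over more steps reaches near-pure noise earlier, so a larger share of the uniformly sampled training timesteps are ones where predicting the noise is easy. A minimal sketch, assuming betas spaced linearly from beta_start to beta_end and using alpha_bar < 0.01 as an arbitrary cut-off for "mostly noise":

```python
import numpy as np


def fraction_mostly_noise(num_train_timesteps, beta_start=0.0001, beta_end=0.0195, threshold=0.01):
    """Share of timesteps whose signal weight alpha_bar_t is below `threshold` (linear betas assumed)."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
    alpha_bar = np.cumprod(1.0 - betas)
    return float(np.mean(alpha_bar < threshold))


for T in (1000, 2000):
    print(f"T={T}: {fraction_mostly_noise(T):.2f} of training timesteps have alpha_bar < 0.01")
```

A lower average loss under the longer schedule therefore partly reflects where the training timesteps fall, not better denoising at the steps that matter for sample quality.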

It might also be worth checking samples with the DDPM sampler without reducing the number of inference steps here. As you can see in this tutorial, in the early stages of training you get better samples from the DDPM sampler, and it takes a while before the PNDM sampler starts to give reasonable samples.