CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

Proper Learning Rate decay for training the LDM #148

Open Adamdad opened 2 years ago

Adamdad commented 2 years ago

In the current repo, I see no learning rate decay in the config files, and the main paper does not mention it at all. However, when I train your model, it does not converge well if no learning rate decay is applied.

Do you apply learning rate decay in your experiments?
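For readers unsure what "learning rate decay" would look like here, the following is a minimal sketch of one common choice: linear warm-up followed by cosine decay, written as a multiplier on the base learning rate. It is not taken from the repo or the paper; the function name and every default value (`warmup_steps`, `total_steps`, `min_factor`) are hypothetical placeholders.

```python
import math


def warmup_cosine_factor(step, warmup_steps=1000, total_steps=100_000, min_factor=0.01):
    """Return the multiplier applied to the base LR at a given optimizer step.

    All defaults are illustrative assumptions, not values from the repo.
    """
    if step < warmup_steps:
        # Linear warm-up from 1/warmup_steps up to 1.0.
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # Cosine curve going from 1.0 (start of decay) to 0.0 (end of decay).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Rescale so the multiplier ends at min_factor instead of 0.
    return min_factor + (1.0 - min_factor) * cosine
```

A function of this shape can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by the returned factor at each scheduler step. Whether cosine, linear, or step decay works best for a given LDM run is something to check empirically.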

clearlyzero commented 11 months ago

May I ask whether you encountered this problem in the first stage or the second stage?

Adamdad commented 11 months ago

Second stage, when training the LDM.

clearlyzero commented 11 months ago

I encountered this problem too. I trained many times, but the loss stayed at 0.4. Have you solved this problem?

Adamdad commented 11 months ago

Applying LR decay may help.

clearlyzero commented 11 months ago

Is the initial learning rate set to 5e-5?

Adamdad commented 11 months ago

I don't know your specific application, but the default should be roughly 5e-5 * bs_size.
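As a worked example of the scaling mentioned above, here is a minimal sketch that multiplies a base rate by the batch size. It assumes the rate is also scaled by the number of GPUs and the number of gradient-accumulation steps when those are used; the function and parameter names are hypothetical, and the exact rule used by a given training script should be checked there.

```python
def effective_lr(base_lr, batch_size, n_gpus=1, accumulate_grad_batches=1):
    """Scale a base learning rate linearly with the effective batch size.

    Assumption: effective batch = batch_size * n_gpus * accumulate_grad_batches.
    """
    return accumulate_grad_batches * n_gpus * batch_size * base_lr


# Example: base rate 5e-5 with a per-GPU batch size of 8 on a single GPU.
lr = effective_lr(5e-5, batch_size=8)
```

With these inputs the optimizer would run at 8 times the base rate, so a base rate of 5e-5 is not the value the optimizer actually sees unless the batch size is 1.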

clearlyzero commented 11 months ago

Thank you very much for your reply. I will try it right away.

clearlyzero commented 11 months ago

I tried it and found that it still didn't work. The loss value still remained at 0.4.

clearlyzero commented 11 months ago

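I feel like there is something wrong with how I process x0. I encode the input first:

    x_start = AutoEncoderKl.Encoder(x_start).sample()

and then compute the loss like this:

    def p_losses(self, x_start, t, img_r, noise=None):
        b, c, h, w = x_start.shape
        # Sample Gaussian noise unless the caller supplies it.
        noise = default(noise, lambda: torch.randn_like(x_start))

        # Diffuse x_start to timestep t, then let the network predict the noise.
        # img_r is an extra input forwarded to the denoiser.
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, img_r)

        # Compare the predicted noise with the noise that was actually added.
        if self.loss_type == 'l1':
            loss = (noise - x_recon).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

drx-code commented 2 months ago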

Hi, did you solve that problem?
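
For reference, the `q_sample` call in the `p_losses` snippet above presumably implements the standard DDPM forward noising step, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. The sketch below shows that step on plain scalars so it runs without any dependencies. The linear beta schedule and its endpoints are the DDPM paper defaults and are an assumption here; the schedule in the poster's code may differ, and all function names are hypothetical.

```python
import math


def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linearly spaced betas (assumed defaults; a real config may use others)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample_scalar(x_start, t, noise, alpha_bars):
    """Forward noising of one scalar value at timestep t."""
    a = alpha_bars[t]
    return math.sqrt(a) * x_start + math.sqrt(1.0 - a) * noise


alpha_bars = cumulative_alphas(linear_beta_schedule())
```

Because the signal and noise coefficients mix `x_start` and `noise` directly, the scale of the encoded latents passed in as `x_start` affects how much signal survives at each timestep, which is one reason the x0 preprocessing is worth checking.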