CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

Getting loss=nan when fine-tuning the VAE #176

Open eeyrw opened 1 year ago

eeyrw commented 1 year ago

I found that the NaN comes from this method in ldm/modules/losses/contperceptual.py:

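    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        # Gradients of the reconstruction (NLL) loss and the generator (GAN) loss
        # with respect to the weights of the decoder's last layer
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        # Ratio of gradient norms balances the GAN loss against the reconstruction loss;
        # the 1e-4 only guards against a zero denominator
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        # Bound the weight to [0, 1e4] and stop gradients through it (clamp passes nan through)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight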
AlexWortega commented 1 year ago

You should try running the VAE and the optimizer in fp32.
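In a plain PyTorch setup, that advice could look roughly like the sketch below. It is only a sketch under stated assumptions: the tiny nn.Sequential model is a hypothetical stand-in for the real VAE, the learning rate is an arbitrary placeholder, and the thread does not show the actual training code. The parameters are cast back to float32 with .float(), and the standard torch.optim.Adam replaces the 8-bit Adam (presumably bitsandbytes' Adam8bit), so both the weights and the optimizer state stay in full precision.

```python
import torch
from torch import nn

# Hypothetical stand-in for the VAE; not the ldm AutoencoderKL.
vae = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 3, 3, padding=1))
vae = vae.half()  # the fp16 setup described in the thread

# Cast the parameters back to fp32 ...
vae = vae.float()
# ... and use the standard Adam, whose state tensors match the fp32 parameters.
opt = torch.optim.Adam(vae.parameters(), lr=1e-5)  # placeholder learning rate

x = torch.randn(2, 3, 16, 16)  # fp32 input to match the fp32 weights
loss = (vae(x) - x).abs().mean()
opt.zero_grad()
loss.backward()
opt.step()

print(all(p.dtype == torch.float32 for p in vae.parameters()))  # True
```

The cost is memory: fp32 weights plus two fp32 Adam moments per parameter, which is what the 8-bit optimizer and fp16 were saving.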

eeyrw commented 1 year ago

It's amazing that you guessed I'm using adam8bit and fp16.

keyu-tian commented 1 year ago

@eeyrw did you see any improvement after fine-tuning the VAE?

eeyrw commented 1 year ago

No. I don't have enough GPU RAM, so I couldn't experiment further.

keyu-tian commented 1 year ago

@eeyrw I got NaN too, but not there. Mine came from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py#L117. I fixed it by replacing that line with torch.sqrt(torch.sum(x**2,dim=1,keepdim=True) + eps).
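The effect of that change can be checked on a toy input. The sketch assumes the line at lpips.py#L117 is the norm computation in taming's normalize_tensor, roughly norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) followed by return x/(norm_factor+eps); both function names below are hypothetical. When a feature vector is all zeros (plausible for post-ReLU features, though the thread does not say where the zeros come from), the forward pass is fine, but the backward pass of sqrt divides by 2*sqrt(0) = 0, so the gradient becomes NaN. Adding eps inside the sqrt keeps its argument positive.

```python
import torch

def normalize_tensor_original(x, eps=1e-10):
    # Assumed original form: eps is added only outside the sqrt
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

def normalize_tensor_fixed(x, eps=1e-10):
    # The fix from the thread: eps inside the sqrt, so sqrt never sees 0
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + eps)
    return x / (norm_factor + eps)

def grad_of(fn):
    # Feature map whose channel vector is all zeros at every spatial position
    x = torch.zeros(1, 3, 2, 2, requires_grad=True)
    fn(x).sum().backward()
    return x.grad

g_orig = grad_of(normalize_tensor_original)
g_fixed = grad_of(normalize_tensor_fixed)
print(torch.isnan(g_orig).any().item())      # True
print(torch.isfinite(g_fixed).all().item())  # True
```

Because the forward output is still finite, a NaN like this would typically surface as loss=nan only after an optimizer step has written NaN into the weights.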

eeyrw commented 1 year ago

@keyu-tian Nice, the eps trick improves numerical stability a lot 😀