Stability-AI / stablediffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

Potential bug: context and y parameters discarded in ImageEmbeddingConditionedLatentDiffusion? #310

Open · Assoillyng opened this issue 1 year ago

Assoillyng commented 1 year ago

I might be misunderstanding the implementation, but from what I can trace, it seems like context and y aren't being utilized in the p_losses method. Could someone provide some clarity on this?

Expected Behavior: I expected context and y to be used within the p_losses method, given that they're passed as arguments in the forward method of LatentDiffusion.

Actual Behavior: context and y appear to be discarded and not used in p_losses.

Context: In v2-1-stable-unclip-h-inference.yaml we can see that the target is ImageEmbeddingConditionedLatentDiffusion and the conditioning key is crossattn-adm.
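
For reference, the relevant lines of that config look roughly like this (paraphrased; the surrounding keys are omitted):

model:
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    conditioning_key: crossattn-adm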

Further, we can see this code in ddpm.DiffusionWrapper:

elif self.conditioning_key == 'crossattn-adm':
    assert c_adm is not None
    cc = torch.cat(c_crossattn, 1)  # c_crossattn is a list of context tensors, concatenated along dim 1
    out = self.diffusion_model(x, t, context=cc, y=c_adm)  # c_adm carries the image embedding
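
For context, here is the enclosing signature (paraphrased from ddpm.DiffusionWrapper, other branches elided), which shows where c_crossattn and c_adm arrive as keyword arguments:

def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
    ...  # dispatches on self.conditioning_key, as in the branch quoted above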

And we see that ImageEmbeddingConditionedLatentDiffusion inherits from LatentDiffusion and has no forward method of its own, so it must use the forward method from LatentDiffusion:

def forward(self, x, c, *args, **kwargs):
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)  # any extra kwargs are forwarded to p_losses

This forwards the keyword arguments (context and y) through to p_losses, whose signature has no **kwargs to receive them:

def p_losses(self, x_start, cond, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond)

    loss_dict = {}
    prefix = 'train' if self.training else 'val'

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    elif self.parameterization == "v":
        target = self.get_v(x_start, noise, t)
    else:
        raise NotImplementedError()

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

    logvar_t = self.logvar[t].to(self.device)
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    # loss = loss_simple / torch.exp(self.logvar) + self.logvar
    if self.learn_logvar:
        loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
        loss_dict.update({'logvar': self.logvar.data.mean()})

    loss = self.l_simple_weight * loss.mean()

    loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
    loss += (self.original_elbo_weight * loss_vlb)
    loss_dict.update({f'{prefix}/loss': loss})

    return loss, loss_dict
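
As a quick aside (a minimal, hypothetical stand-in, not the repo's code): a function without **kwargs doesn't silently drop unexpected keyword arguments, so if context and y really were being passed in here, this call would crash rather than quietly ignore them:

def p_losses_like(x_start, cond, t, noise=None):  # same shape of signature as above
    return x_start

p_losses_like(1, 2, 3, context="cc", y="c_adm")
# TypeError: p_losses_like() got an unexpected keyword argument 'context'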

So, as far as I can trace, neither the image embeddings (c_adm) nor the cross-attention context (cc) is used anywhere.

Am I missing something here or is this a bug?

Assoillyng commented 1 year ago

Turns out I was missing something: self.diffusion_model is of type openaimodel.UNetModel, which does take context and y. That said, I'm still having a lot of trouble getting unclip to run; the implementation seems to be incomplete.
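
For anyone else tracing this path: if I read it right, c arrives at forward already as a dict (something like {"c_crossattn": [...], "c_adm": ...}), and p_losses hands it to apply_model, which unpacks it into DiffusionWrapper, so context and y are assembled inside the wrapper rather than passed as kwargs. Paraphrased from LatentDiffusion.apply_model (return_ids handling omitted):

def apply_model(self, x_noisy, t, cond, return_ids=False):
    if not isinstance(cond, dict):
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
        cond = {key: cond if isinstance(cond, list) else [cond]}
    return self.model(x_noisy, t, **cond)  # -> DiffusionWrapper.forward(..., c_crossattn=..., c_adm=...)

And the UNet's forward (paraphrased from openaimodel.UNetModel, body elided) does accept both arguments:

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    # context feeds the cross-attention layers in the UNet blocks;
    # y (c_adm, the image embedding here) is projected and added to the
    # timestep embedding.
    ...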

TimSYQQX commented 7 months ago

I second this. The implementation seems to be incomplete.