lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

Question regarding clamping of x_recon in DiffusionPrior and Decoder #39

Closed · xiankgx closed this issue 2 years ago

xiankgx commented 2 years ago

DiffusionPrior is configured by default to predict_x_start. As a result, x_recon is not clamped to [-1, 1], which I think is good, because we don't know the output range of CLIP image embeddings.

https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L770-L778

Decoder is configured by default to not predict_x_start. As a result, x_recon is clamped to [-1, 1], which I think is also good, because we arrange the dataset or dataloader so that the model predicts images in the range [-1, 1]. However, I don't understand why clamping is conditioned on not predicting x_start. Whether we predict noise or x_start, x_recon is going to be x_start either way.

https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L1440-L1446

Please enlighten me.

lucidrains commented 2 years ago

@xiankgx hmm, i don't understand your second paragraph here

the training objective is different depending on whether you are predicting x_start or not https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L1515
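
roughly, the switch is just a change of regression target. a self-contained paraphrase (not the repo's exact code, and names may differ):

import torch.nn.functional as F

# paraphrase of the objective switch in the loss: the unet regresses either
# x_start itself or the injected noise, depending on the parameterization
def diffusion_loss(pred, x_start, noise, predict_x_start):
    target = x_start if predict_x_start else noise
    return F.mse_loss(pred, target)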

lucidrains commented 2 years ago

i am also totally not confident about the new objective, as evidenced by the in-line comments in the code, so if you find another paper that uses this objective, i would definitely be appreciative

lucidrains commented 2 years ago

ok, time to walk my dog 🐕 be back later!

xiankgx commented 2 years ago

I understand the model can either predict x_start directly, or predict the noise from which x_start is then derived.

However, from the code, we can see that no matter which of the two the model predicts, x_recon is x_start.

def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
    pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

    if predict_x_start:
        x_recon = pred
    else:
        x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

    if clip_denoised and not predict_x_start:
        x_recon.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance
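
(For reference, predict_start_from_noise is the standard DDPM inversion of the forward process; a self-contained sketch of the identity, not the repo's exact buffers:)

# recover x_0 from x_t and the predicted noise, per the DDPM identity
# x_0 = (x_t - sqrt(1 - alpha_bar_t) * noise) / sqrt(alpha_bar_t)
# x_t, alpha_bar_t and noise are broadcastable torch tensors
def predict_start_from_noise(x_t, alpha_bar_t, noise):
    return (x_t - (1 - alpha_bar_t).sqrt() * noise) / alpha_bar_t.sqrt()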

x_recon is x_start because of these lines:

if predict_x_start:
    x_recon = pred
else:
    x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

Hence, I don't understand why we clamp x_recon conditioned on whether the model predicts x_start directly or via noise, as in the following lines:

if clip_denoised and not predict_x_start:
    x_recon.clamp_(-1., 1.)

lucidrains commented 2 years ago

oh I understand! Yes this makes sense for the decoder, but not for the diffusion prior (although for the prior, do you think we could also clamp with l2norm?)
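
a hypothetical sketch of such an l2norm clamp at sampling time (F.normalize is one way to implement l2norm; whether the result needs rescaling to match CLIP embedding norms is an open question):

import torch.nn.functional as F

# hypothetical: instead of clamp_(-1., 1.), project the predicted image
# embedding back onto the unit hypersphere
def l2norm_clamp(x_recon):
    return F.normalize(x_recon, dim = -1)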

xiankgx commented 2 years ago

I'm really not sure about that.

lucidrains commented 2 years ago

Thanks, I'll make the change once I'm back home.

lucidrains commented 2 years ago

> I'm really not sure about that.

when in doubt, make it a hyperparameter https://github.com/lucidrains/DALLE2-pytorch/commit/77fa34eae90f6bb321d5f461eca3a9094d1cf225 ;)
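
in this shape, roughly (a hypothetical sketch; the actual flag name and wiring in the linked commit may differ):

# decouple the clamp decision from the parameterization with an explicit
# clip_denoised flag, instead of inferring it from `not predict_x_start`
def p_mean_variance(self, unet, x, t, clip_denoised = True, predict_x_start = False, **cond_kwargs):
    pred = unet.forward_with_cond_scale(x, t, **cond_kwargs)

    x_recon = pred if predict_x_start else self.predict_start_from_noise(x, t = t, noise = pred)

    if clip_denoised:  # no longer tied to `not predict_x_start`
        x_recon.clamp_(-1., 1.)

    return self.q_posterior(x_start = x_recon, x_t = x, t = t)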

lucidrains commented 2 years ago

@xiankgx so i realized the reason i didn't clip in the Decoder is that i introduced latent diffusion. however, there is an improved VQGAN variant out there that proposes to l2norm the codebook, so perhaps if we figure out that l2norm clamping works, we can also add that to the sampling steps as an extra guardrail
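
for reference, a minimal self-contained sketch of the l2norm codebook lookup idea (illustrative names, not the paper's or the repo's):

import torch
import torch.nn.functional as F

# l2-normalize both encoder outputs and codes, so nearest-neighbour lookup
# happens on the unit hypersphere (i.e. by cosine similarity)
def l2norm_codebook_lookup(z, codebook):
    # z: (batch, n, dim) encoder outputs, codebook: (num_codes, dim)
    z = F.normalize(z, dim = -1)
    codebook = F.normalize(codebook, dim = -1)
    sim = torch.einsum('b n d, c d -> b n c', z, codebook)
    indices = sim.argmax(dim = -1)
    return codebook[indices], indices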