lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

ddim makes the generation worse #231

Open YUHANG-Ma opened 2 years ago

YUHANG-Ma commented 2 years ago

Hi, I ran into an issue: when I use DDIM sampling for the decoder, the generated pics don't look good (image attached). When I change the sample steps to 1000, I get the following result (image attached). Could I ask how to fix it?

The following is the ddim part of my code.

```python
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
    device = self.betas.device

    b = shape[0]
    img = torch.randn(shape, device = device)
    timesteps = 250
    times = torch.linspace(0., 1000, steps = timesteps + 2)[:-1]

    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))
    print(time_pairs)

    alphas = self.alphas_cumprod_prev

    if not is_latent_diffusion:
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        alpha = alphas[time]
        alpha_next = alphas[time_next]

        time_cond = torch.full((b,), time, device = device, dtype = torch.long)

        pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

        if learned_variance:
            pred, _ = pred.chunk(2, dim = 1)

        if predict_x_start:
            x_start = pred
            pred_noise = self.predict_noise_from_start(img, t = time_cond, x0 = pred)
        else:
            x_start = self.predict_start_from_noise(img, t = time_cond, noise = pred)
            pred_noise = pred

        if clip_denoised:
            s = 1.
            # clip by threshold, depending on whether static or dynamic
            x_start = x_start.clamp(-s, s) / s

        c1 = 1 * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
        noise = torch.randn_like(img) if time_next > 0 else 0.

        img = x_start * alpha_next.sqrt() + \
              c1 * noise + \
              c2 * pred_noise

    img = self.unnormalize_img(img)
    return img
```
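
For reference, here is a minimal, self-contained sketch of the single-step DDIM update rule from the DDIM paper (Song et al.), written with an explicit `eta` parameter. The names `ddim_step`, `eta`, and `last_step` are my own for illustration, not the repository's API. With `eta = 0` the update is fully deterministic; note that the `c1` term in the loop above uses a constant factor of 1, which corresponds to `eta = 1`.

```python
import torch

def ddim_step(img, x_start, pred_noise, alpha, alpha_next, eta = 0., last_step = False):
    # alpha and alpha_next are scalar tensors taken from alphas_cumprod at the
    # current and next (earlier) timestep of the shortened sampling schedule.
    # sigma controls how much fresh noise is injected: eta = 0 is deterministic
    # DDIM, eta = 1 recovers DDPM-like stochasticity.
    sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c = ((1 - alpha_next) - sigma ** 2).sqrt()
    noise = torch.randn_like(img) if not last_step else 0.

    # x_next = sqrt(alpha_next) * x0_hat + c * eps_hat + sigma * z
    return x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
```

With `eta = 0` the noise term drops out entirely, which is the setting typically used when sampling with far fewer steps than the training schedule.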
FTKyaoyuan commented 2 years ago

Can you tell me which dataset you used? Thanks.

YUHANG-Ma commented 2 years ago

> Can you tell me which dataset you used? Thanks.

Pics from the Internet.

FTKyaoyuan commented 2 years ago

> Can you tell me which dataset you used? Thanks.
>
> Pics from the Internet.

Can I see the training code you used?