explainingai-code / DiT-PyTorch

This repo implements Diffusion Transformers (DiT) in PyTorch and provides training and inference code for the CelebHQ dataset.

DiT for Inpainting Task #2

Open Chinafsh opened 3 weeks ago

Chinafsh commented 3 weeks ago

First, thanks for your YouTube video and code. By the way, have you found any way to use DiT for the inpainting task?

explainingai-code commented 3 weeks ago

Thank you for the appreciation, @Chinafsh. I haven't made any effort towards inpainting with DiT, but as with any diffusion model, simply blending the generated latent pixels for the masked regions with the noisy latent pixels of the original image for the non-masked regions should give us a decent starting point.
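To make the blending step concrete, here is a minimal NumPy sketch of a single blend at one timestep. It assumes a DDPM-style linear beta schedule (the values 1e-4 to 0.02 over 1000 steps are illustrative and not taken from this repo), and `make_alpha_cum_prod`, `add_noise` and `blend` are hypothetical helpers written for this example, not the repo's scheduler API. Random arrays stand in for the VAE latent and the sampler output.

```python
import numpy as np

def make_alpha_cum_prod(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed linear beta schedule; the repo's scheduler may use different values
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    return np.cumprod(1.0 - betas)

def add_noise(x0, noise, t, alpha_cum_prod):
    # Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    abar = alpha_cum_prod[t]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise

def blend(generated, known, mask):
    # mask is 1 where content is generated and 0 where the original is kept
    return mask * generated + (1.0 - mask) * known

rng = np.random.default_rng(0)
alpha_cum_prod = make_alpha_cum_prod()
original_latent = rng.standard_normal((1, 4, 8, 8))   # stand-in for a VAE latent
generated_latent = rng.standard_normal((1, 4, 8, 8))  # stand-in for the sampler output
mask = np.zeros((1, 1, 8, 8))
mask[:, :, 2:6, 2:6] = 1.0                            # centre region to inpaint

# Noise the known region to the same level as the generated latent, then blend
noise = rng.standard_normal(original_latent.shape)
noisy_original = add_noise(original_latent, noise, 500, alpha_cum_prod)
blended = blend(generated_latent, noisy_original, mask)
```

The known region is noised to the current timestep's level so that both halves of the blended latent have matching noise statistics when they go back into the model.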

For example, when I tried this approach on the first image with a centre mask, I got these three generated results. Obviously, with actual training for inpainting, the results would get significantly better.

Adding sample code to get this result using a DiT trained on the CelebHQ dataset.

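import torch
from tqdm import tqdm

# Assumes vae, model, scheduler, mask, original_image, device and diffusion_config are already set up
original_image_latent, _ = vae.encode(original_image)
# Start sampling from pure noise in latent space
xt = torch.randn_like(original_image_latent)
for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
    noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
    # Use scheduler to get x0 and xt-1
    xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

    # mask is 1 where we want to generate and 0 where we want to keep
    if i > 0:
        # xt is now at noise level i - 1, so noise the original latent to the same level
        noise = torch.randn_like(original_image_latent)
        noisy_original_image_latent = scheduler.add_noise(original_image_latent, noise, torch.as_tensor(i - 1).to(device))
        xt = mask * xt + (1 - mask) * noisy_original_image_latent
    else:
        # Final step: paste the clean original latent into the kept region
        xt = mask * xt + (1 - mask) * original_image_latent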
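The sampling loop takes `mask` as given. As a rough sketch, a centre mask at latent resolution could be built as below. The helper name `make_centre_mask`, the 32x32 latent size and the 0.5 box fraction are all assumptions for illustration; the real latent size depends on the VAE's downsampling factor.

```python
import numpy as np

def make_centre_mask(latent_h, latent_w, frac=0.5):
    # Returns a (1, 1, H, W) mask: 1 inside the centred box (generate), 0 outside (keep)
    mask = np.zeros((1, 1, latent_h, latent_w), dtype=np.float32)
    box_h, box_w = int(latent_h * frac), int(latent_w * frac)
    top, left = (latent_h - box_h) // 2, (latent_w - box_w) // 2
    mask[:, :, top:top + box_h, left:left + box_w] = 1.0
    return mask

# Illustrative latent size; use the spatial shape of original_image_latent in practice
mask = make_centre_mask(32, 32, frac=0.5)
```

Converting it with `torch.from_numpy(mask).to(device)` should give a tensor that broadcasts across the latent channels in the blend.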

Hope this helps.