isno0907 / isodiff


Request for Guidance on Achieving Image Interpolation, Inversion, and Reconstruction Results #1


Gp1g commented 1 month ago

Hello,

First of all, thank you for sharing your work and providing such a valuable resource. I am currently exploring your project and am particularly interested in how to reproduce the Image Interpolation, Image Inversion, and Reconstruction results.

Could you please provide some guidance or a brief walkthrough on how to perform these tasks using your codebase? Any example scripts, configurations, or specific instructions would be greatly appreciated.

Thank you for your time and assistance.

Best regards,

AvirupJU commented 2 weeks ago

Hi, I too am trying to replicate those aspects of this project. I do have some baselines for inversion and interpolation with the vanilla ldm-celebahq-256 model.

For inversion, I have used the VQVAE encoder from the diffusers library to get the latent code:

from diffusers import UNet2DModel, VQModel, DDIMScheduler

net = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_pretrained("CompVis/ldm-celebahq-256", subfolder="scheduler")

[Use a basic denoising loop to get back the image]
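A minimal sketch of that denoising loop, using the objects loaded above (this is only my baseline, not the authors' script; `latents` is assumed to hold the noised latent code, e.g. from scheduler.add_noise):

scheduler.set_timesteps(num_inference_steps=1000)
with torch.no_grad():
    for t in scheduler.timesteps:
        residual = net(latents, t)["sample"]                    # predicted noise
        latents = scheduler.step(residual, t, latents)["prev_sample"]
    recon = vqvae.decode(latents)["sample"]                     # back to pixel space, roughly in [-1, 1]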

For interpolating between images, I have used the basic slerp formula on the latent code:

# Get the angle between unit vectors of the source and target latents
def cos(a, b):
    a = a.view(-1)
    b = b.view(-1)
    a = F.normalize(a, dim=0)
    b = F.normalize(b, dim=0)
    return (a * b).sum()

batch = torch.cat([image_S, image_T], dim=0)
img_enc = vqvae.encode(batch.to(device))["latents"]
noise = torch.randn_like(img_enc).to(device)
noisy_image = scheduler.add_noise(img_enc, noise, timesteps)
scheduler.set_timesteps(num_inference_steps=1000)
xT = noisy_image.cpu()
alpha = torch.tensor(np.linspace(0, 1, 6, dtype=np.float32)).to(xT.device)
theta = torch.arccos(cos(xT[1], xT[0]))
x_shape = xT[0].shape
intp_x = (torch.sin((1 - alpha[:, None]) * theta) * xT[1].flatten(0, 2)[None]
          + torch.sin(alpha[:, None] * theta) * xT[0].flatten(0, 2)[None]) / torch.sin(theta)

intp_x = intp_x.view(-1, *x_shape)

[Use a denoise loop to get images from intp_x]
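The denoise loop for intp_x follows the same pattern, batched over the interpolation weights, with a VQVAE decode at the end (again a sketch under the same assumptions):

with torch.no_grad():
    x = intp_x.to(device)
    for t in scheduler.timesteps:
        residual = net(x, t)["sample"]
        x = scheduler.step(residual, t, x)["prev_sample"]
    frames = vqvae.decode(x)["sample"]  # one decoded image per interpolation weight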

I hope the authors will release the checkpoints for the isometric models soon. Best of luck!

Jaehoon-zx commented 1 week ago

Thank you for your interest in our work. Below is a brief description of how to obtain the image interpolation and inversion/reconstruction results. In summary, we make use of slerp, DDIMScheduler, and DDIMInverseScheduler imported from diffusers.

We will shortly update the code with the scripts we used to obtain these results, along with the checkpoints. We hope this helps!

Slerp

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # Spherical linear interpolation between v0 and v1 at weight t.
    inputs_are_torch = isinstance(v0, torch.Tensor)
    if inputs_are_torch:
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()
        t = t.cpu().numpy()

    # Cosine of the angle between v0 and v1
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    # Nearly parallel vectors: fall back to linear interpolation
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
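For reference, a usage example (assuming z0 and z1 are equal-shaped latent tensors):

w = torch.tensor(0.5)
z_mid = slerp(w, z0, z1)  # midpoint on the great circle between z0 and z1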

Image interpolation

import torch
from diffusers import UNet2DModel, DDIMScheduler
from torchvision.utils import make_grid

unet = UNet2DModel.from_pretrained(args.unet_path)
scheduler = DDIMScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)

for n, seed in enumerate(seeds):
    images_x_slerp = []

    generator = torch.Generator(device=torch_device)
    generator.manual_seed(seed)
    noise = torch.randn(
        (2, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size),
        device=torch_device, generator=generator,
    )
    z0, z1 = noise[0].unsqueeze(0), noise[1].unsqueeze(0)

    step = args.step
    for w in torch.arange(1e-4, 1 + step, step, device = torch_device):
        image_x_slerp = slerp(w, z0, z1)

        for i, t in enumerate(scheduler.timesteps):
            with torch.no_grad():
                residual_x_slerp = unet(image_x_slerp, t)["sample"]

            image_x_slerp = scheduler.step(residual_x_slerp, t, image_x_slerp, eta=0.0)["prev_sample"]

        images_x_slerp.append(image_x_slerp/2 + 0.5)

    grid_x_slerp = make_grid(torch.cat(images_x_slerp, 0), nrow=int(1 / step) + 1, padding=0)
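To write the grid to disk (still inside the seed loop; the filename is illustrative):

    from torchvision.utils import save_image
    save_image(grid_x_slerp, f"interp_seed_{seed}.png")  # values already rescaled to [0, 1]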

Image inversion/reconstruction

import torch
import numpy as np
import PIL.Image
from diffusers import DDIMScheduler, DDIMInverseScheduler
from torchvision.io import read_image

unet = UNet2DModel_H.from_pretrained(args.unet_path)  # UNet2DModel_H: UNet class from this repository
scheduler = DDIMScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
inverse_scheduler = DDIMInverseScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)
inverse_scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)

for seed in seeds:
    image = read_image(f'path_to_images/image_{seed}.png').to(torch_device).to(torch.float).unsqueeze(0) / 127.5 - 1

    for t in inverse_scheduler.timesteps[:args.inversion_steps]:
        with torch.no_grad():
            residual = unet(image, t)[0]["sample"]

        prev_image = inverse_scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
        image = prev_image

    for t in scheduler.timesteps[-args.inversion_steps:]:
        with torch.no_grad():
            residual = unet(image, t)[0]["sample"]

        prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
        image = prev_image

    if image_save:
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
        image_pil = PIL.Image.fromarray(image_processed[0])
        image_pil.save(path + f"/recovered_image_{seed}.png")
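As written, the loops invert only the first args.inversion_steps timesteps of the inverse schedule and then denoise over the matching tail of the forward schedule; setting args.inversion_steps equal to args.num_inference_steps gives the full inversion to noise and reconstruction back.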
Gp1g commented 3 days ago

Thanks for your help. However, when I run the image inversion/reconstruction code with both the checkpoint you provide and the official UNet, I get bad images. Are any changes needed to reproduce the results?

Results:

[image: recovered_image_0]

Input image:

[image: 00000]

isno0907 commented 2 days ago

Yes, as you mentioned, we also observed poor performance with DDPM models on the inversion task; this issue has been highlighted in many previous works. To address it, we used ADM models for the visual comparison of inversion and reconstruction between our approach and the baseline.

We've updated the code to support ADM inversion and reconstruction and uploaded the corresponding weights. Please check the new code and the updated README.md for the pretrained weights link.

Thanks.