Open Gp1g opened 1 month ago
Hi, I too am trying to replicate those aspects of this project. I do have some baselines for inversion and interpolation with the vanilla ldm-celebahq-256 model.
For inversion, I have used the VQVAE encoder from the diffusers library to get the latent code:
from diffusers import UNet2DModel, VQModel, DDIMScheduler

net = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_pretrained("CompVis/ldm-celebahq-256", subfolder="scheduler")
[Use a basic denoising loop to get back the image]
For interpolating between images I have used the basic Slerp formula on the latent code:
import numpy as np
import torch
import torch.nn.functional as F

# Cosine of the angle between the flattened source and target latents
def cos(a, b):
    a = a.view(-1)
    b = b.view(-1)
    a = F.normalize(a, dim=0)
    b = F.normalize(b, dim=0)
    return (a * b).sum()

batch = torch.cat([image_S, image_T], dim=0)
img_enc = vqvae.encode(batch.to(device))['latents']
noise = torch.randn_like(img_enc).to(device)
# `timesteps` (the noise level to add) is not defined in this snippet
noisy_image = scheduler.add_noise(img_enc.to(device), noise.to(device), timesteps)
scheduler.set_timesteps(num_inference_steps=1000)
xT = noisy_image.cpu()
alpha = torch.tensor(np.linspace(0, 1, 6, dtype=np.float32)).to(xT.device)
theta = torch.arccos(cos(xT[1], xT[0]))
x_shape = xT[0].shape
# Slerp between the two noisy latents, one row per alpha
intp_x = (torch.sin((1 - alpha[:, None]) * theta) * xT[1].flatten(0, 2)[None] + torch.sin(alpha[:, None] * theta) * xT[0].flatten(0, 2)[None]) / torch.sin(theta)
intp_x = intp_x.view(-1, *x_shape)
[Use a denoise loop to get images from intp_x]
I hope the authors will release the checkpoints for the isometric models soon. Best of luck!
Thank you for your interest in our work. Below is a brief description of how to obtain the image interpolation and inversion/reconstruction results. In summary, we use slerp together with DDIMScheduler and DDIMInverseScheduler from diffusers.
We will update our code shortly to include the scripts we used to obtain the results, along with the checkpoints. We hope this helps!
Slerp
import numpy as np
import torch

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # Work in NumPy; remember the device so the result can be moved back
    inputs_are_torch = isinstance(v0, torch.Tensor)
    if inputs_are_torch:
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()
        t = t.cpu().numpy()
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly parallel vectors: fall back to linear interpolation
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1
    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2
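For anyone wondering why slerp is used here rather than a plain linear blend, here is a minimal NumPy-only sketch. The helper `slerp_np` and the random test vectors are made up for illustration and are not part of the project's code; the sketch only checks that slerp hits both endpoints and keeps the norm of unit vectors, while the linear midpoint shrinks.

```python
import numpy as np

def slerp_np(t, v0, v1, dot_threshold=0.9995):
    """Illustrative NumPy-only slerp (hypothetical helper, mirrors the function above)."""
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:
        # Nearly parallel vectors: linear interpolation is numerically safer
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)
    s0 = np.sin((1 - t) * theta_0) / np.sin(theta_0)
    s1 = np.sin(t * theta_0) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

# Two random unit vectors standing in for flattened noise latents (assumption)
rng = np.random.default_rng(0)
v0 = rng.standard_normal(1024)
v1 = rng.standard_normal(1024)
v0 /= np.linalg.norm(v0)
v1 /= np.linalg.norm(v1)

mid_slerp = slerp_np(0.5, v0, v1)
mid_lerp = 0.5 * v0 + 0.5 * v1

print("norm of slerp midpoint:", np.linalg.norm(mid_slerp))
print("norm of lerp midpoint: ", np.linalg.norm(mid_lerp))
```

The slerp midpoint stays on the unit sphere, whereas the linear midpoint of two nearly orthogonal vectors has a clearly smaller norm, which moves it away from the norm range the model saw for Gaussian noise during training.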
Image interpolation
import torch
from diffusers import UNet2DModel, DDIMScheduler, DDIMInverseScheduler
from torchvision.utils import make_grid

unet = UNet2DModel.from_pretrained(args.unet_path)
scheduler = DDIMScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)

for n, seed in enumerate(seeds):
    images_x_slerp = []
    # Use the same device as the model so this also runs on CPU
    generator = torch.Generator(device=torch_device)
    generator.manual_seed(seed)
    noise = torch.randn(
        (2, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size),
        device=torch_device, generator=generator,
    )
    z0, z1 = noise[0].unsqueeze(0), noise[1].unsqueeze(0)
    step = args.step
    for w in torch.arange(1e-4, 1 + step, step, device=torch_device):
        # Interpolate the initial noise, then run deterministic DDIM sampling
        image_x_slerp = slerp(w, z0, z1)
        for i, t in enumerate(scheduler.timesteps):
            with torch.no_grad():
                residual_x_slerp = unet(image_x_slerp, t)["sample"]
            image_x_slerp = scheduler.step(residual_x_slerp, t, image_x_slerp, eta=0.0)["prev_sample"]
        # Map from [-1, 1] to [0, 1] for visualization
        images_x_slerp.append(image_x_slerp / 2 + 0.5)
    grid_x_slerp = make_grid(torch.cat(images_x_slerp, 0), nrow=int(1 / step) + 1, padding=0)
Image inversion/reconstruction
import numpy as np
import PIL.Image
import torch
from diffusers import UNet2DModel, DDIMScheduler, DDIMInverseScheduler
from torchvision.io import read_image

# UNet2DModel_H is not part of diffusers; its import is not shown in this snippet
unet = UNet2DModel_H.from_pretrained(args.unet_path)
scheduler = DDIMScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
inverse_scheduler = DDIMInverseScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)
inverse_scheduler.set_timesteps(num_inference_steps=args.num_inference_steps)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)

for seed in seeds:
    # Load the image and scale pixel values from [0, 255] to [-1, 1]
    image = read_image(f'path_to_images/image_{seed}.png').to(torch_device).to(torch.float).unsqueeze(0) / 127.5 - 1
    # Inversion: step from the image towards noise
    for t in inverse_scheduler.timesteps[:args.inversion_steps]:
        with torch.no_grad():
            residual = unet(image, t)[0]["sample"]
        prev_image = inverse_scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
        image = prev_image
    # Reconstruction: step from the inverted latent back to an image
    for t in scheduler.timesteps[-args.inversion_steps:]:
        with torch.no_grad():
            residual = unet(image, t)[0]["sample"]
        prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
        image = prev_image
    if image_save:
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
        image_pil = PIL.Image.fromarray(image_processed[0])
        image_pil.save(path + f"/recovered_image_{seed}.png")
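To make the inversion/reconstruction round trip a bit more concrete, here is a toy NumPy sketch of the deterministic DDIM update (eta = 0) run first towards noise and then back. Everything in it is an assumption for illustration: the noise predictors `eps_constant` and `eps_dependent` are hypothetical stand-ins for the UNet, and the `abars` schedule is arbitrary. It is not the project's code and does not use the diffusers schedulers.

```python
import numpy as np

def ddim_move(x, eps, abar_from, abar_to):
    """Deterministic DDIM update between two cumulative-alpha noise levels."""
    x0_pred = (x - np.sqrt(1.0 - abar_from) * eps) / np.sqrt(abar_from)
    return np.sqrt(abar_to) * x0_pred + np.sqrt(1.0 - abar_to) * eps

def eps_constant(x, abar):
    """Hypothetical noise predictor that ignores its input."""
    return np.full_like(x, 0.3)

def eps_dependent(x, abar):
    """Hypothetical noise predictor that depends on the current sample."""
    return np.tanh(x)

def round_trip(x_start, eps_fn, abars):
    """Invert x_start along abars (clean to noisy), then walk back."""
    x = x_start.copy()
    # Inversion: the predictor is evaluated at the current, less noisy sample
    for a_from, a_to in zip(abars[:-1], abars[1:]):
        x = ddim_move(x, eps_fn(x, a_from), a_from, a_to)
    # Reconstruction: same levels in reverse order
    for a_from, a_to in zip(abars[:0:-1], abars[-2::-1]):
        x = ddim_move(x, eps_fn(x, a_from), a_from, a_to)
    return x

abars = np.linspace(0.999, 0.05, 11)  # arbitrary toy schedule (assumption)
x_start = np.array([0.5, -0.2, 0.8])

x_rec_const = round_trip(x_start, eps_constant, abars)
x_rec_dep = round_trip(x_start, eps_dependent, abars)

print("error, input-independent predictor:", np.abs(x_rec_const - x_start).max())
print("error, input-dependent predictor:  ", np.abs(x_rec_dep - x_start).max())
```

When the predictor does not depend on the sample, each reconstruction step undoes the matching inversion step exactly. With an input-dependent predictor the two steps evaluate the model at different points, so in general they no longer cancel, which is one way to see why reconstruction quality depends on the model and on the number of steps.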
Thanks for your help. But when I use the image inversion/reconstruction code, with both the checkpoint you provide and the official UNet, I get bad images. Are any changes needed to reproduce the results?
Yes, as you mentioned, we also encountered poor performance with DDPM models in the inversion task. This issue has been highlighted by many previous works. To address this, we used ADM models for a visual comparison between the inversion and reconstruction tasks of our approach versus the baseline.
We've updated the code to support ADM inversion and reconstruction and uploaded the corresponding weights. Please check the new code and the updated README.md for the link to the pretrained weights.
Thanks.
Hello,
First of all, thank you for sharing your work and providing such a valuable resource. I am currently exploring your project and am particularly interested in understanding how to get the results related to Image Interpolation, Image Inversion, and Reconstruction.
Could you please provide some guidance or a brief walkthrough on how to perform these tasks using your codebase? Any example scripts, configurations, or specific instructions would be greatly appreciated.
Thank you for your time and assistance.
Best regards,