jaidevshriram / realmdreamer

Code for RealmDreamer: Text-Driven 3D Scene Generation with Inpainting and Depth Diffusion [Arxiv 2024]

Request for the inpainting model details #3

Closed Rabona17 closed 1 month ago

Rabona17 commented 1 month ago

Dear authors,

Thanks for this awesome work; I am very interested in some of the details! May I ask how the image guidance weight was used with Stable Diffusion 2? Did you use the original sd2-base model and add image_cfg_guidance during sampling, and if so, what would the "null image" be?

Please correct me if I misunderstand the paper, and thanks again.

jaidevshriram commented 1 month ago

Hello, thanks for checking it out!

No, we did not use sd2-base but rather sd2-inpainting, the inpainting variant of the model, which is conditioned on both text AND a mask. Referencing equation 2 in the paper:

*(screenshot of Equation 2 from the paper)*

Here, the two conditioning signals are the text and the mask. During sampling, we therefore need three different predictions from the diffusion model's UNet:

1. No text, no mask (no mask is equivalent to inpainting the whole image)
2. No text, with mask
3. Text, with mask

The image guidance weight is applied to the difference between (2) and (1), and the text guidance weight to the difference between (3) and (2). I hope this clears it up slightly. Here is some code that might help too:


        # Snippet from inside our guidance class; assumes torch, torch.nn.functional as F,
        # and tqdm are imported, and that text_embeddings stacks the [conditional, unconditional]
        # prompt embeddings along the batch dimension.

        # Encode the current image and build the inpainting conditioning:
        # the mask (downsampled to latent resolution) and the masked-image latents.
        latents = self.encode_images(rgb_BCHW)
        og_latents = latents.clone()
        masked_img = (rgb_BCHW * 2 - 1) * (mask < 0.5)  # gray out the region to be inpainted (mask = 1 there)
        masked_img = (masked_img * 0.5 + 0.5).clamp(0, 1)
        masked_latent = self.encode_images(masked_img)
        mask_64 = F.interpolate(
            mask.float(), (64, 64)
        )
        masks = torch.cat([mask_64] * 2)
        masked_latents = torch.cat([masked_latent] * 2)

        # Conditioning for the fully unconditional score: no mask (i.e. inpaint the whole image,
        # so the mask is all ones) and a "null image" that is plain gray
        # (0.5 in [0, 1], i.e. zeros in [-1, 1]).
        mask_64_00 = torch.ones_like(mask_64)
        masked_image_00 = torch.ones_like(masked_img) * 0.5
        masked_latent_00 = self.encode_images(masked_image_00)

        with torch.no_grad():
            for i, t in tqdm(enumerate(timesteps), total=len(timesteps), leave=False):

                # Predictions (3) and (2) in one batched forward pass:
                # (text, mask) and (no text, mask).
                latent_model_input = torch.cat([latents] * 2)

                latent_model_input = torch.cat([
                    latent_model_input,
                    masks,
                    masked_latents
                ], dim=1)

                noise_pred = self.forward_unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                    )

                noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)

                # Prediction (1): fully unconditional (no text, no mask).
                input = torch.cat([
                    latents,
                    mask_64_00,
                    masked_latent_00
                ], dim=1)

                noise_pred_00 = self.forward_unet(
                            input,
                            t,
                            encoder_hidden_states=text_embeddings[1].unsqueeze(0).to(self.weights_dtype)
                        )

                # Dual classifier-free guidance: the image guidance scale weights the masked-image
                # conditioning, the text guidance scale weights the text conditioning.
                noise_pred = noise_pred_00 + self.cfg.img_guidance_scale * (noise_pred_uncond - noise_pred_00) + self.cfg.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latent_dict = self.scheduler.step(noise_pred, t, latents)
                latents = latent_dict["prev_sample"]
                latents_0 = latent_dict["pred_original_sample"]
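
For reference, the final `noise_pred` line is dual classifier-free guidance written out. In equation form (my notation here, not copied verbatim from the paper):

$$
\hat{\epsilon} = \epsilon(x_t, \varnothing, \varnothing)
+ s_{\text{img}}\,\big(\epsilon(x_t, \varnothing, m) - \epsilon(x_t, \varnothing, \varnothing)\big)
+ s_{\text{text}}\,\big(\epsilon(x_t, c, m) - \epsilon(x_t, \varnothing, m)\big)
$$

where $c$ is the text prompt, $m$ is the mask / masked-image conditioning, $\varnothing$ means that conditioning is dropped, and $s_{\text{img}}$, $s_{\text{text}}$ correspond to `img_guidance_scale` and `guidance_scale` above.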

This is just a snippet; for the loss computation we use DDIM inversion, randomly select timesteps, and there is a bit more going on there - happy to discuss.
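For intuition on the inversion part, here is a minimal sketch of a single deterministic DDIM inversion step. This is just the generic DDIM update, not code lifted from our repo, and the names (`ddim_invert_step`, `t_next`, etc.) are purely illustrative:

    import torch

    def ddim_invert_step(x_t, eps, alpha_bar_t, alpha_bar_next):
        # Clean-image estimate implied by the current noise prediction.
        x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
        # Deterministically re-noise that estimate to the *next* (higher) noise level,
        # reusing the same eps instead of sampling fresh noise (the eta = 0 DDIM case).
        return alpha_bar_next.sqrt() * x0_pred + (1 - alpha_bar_next).sqrt() * eps

    # Example usage with a diffusers-style scheduler (alphas_cumprod is a 1D tensor;
    # a larger timestep index means more noise, so inversion steps t -> t_next with t_next > t):
    # x_next = ddim_invert_step(
    #     latents,
    #     noise_pred,
    #     alpha_bar_t=scheduler.alphas_cumprod[t],
    #     alpha_bar_next=scheduler.alphas_cumprod[t_next],
    # )

Applying this step repeatedly takes a clean latent back up to a chosen noise level without sampling random noise.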

Rabona17 commented 1 month ago

Hi Jaidev, thanks for your comprehensive comments! It's much clearer now, and the idea is a really novel way to get better inpainting. Appreciate your help!