threestudio-project / threestudio

A unified framework for 3D content generation.
Apache License 2.0
5.93k stars, 457 forks

A question about gradients [really confused] #424

Status: Open. ByChelsea opened this issue 5 months ago.

ByChelsea commented 5 months ago

While running DreamFusion, I wanted to look more closely at the gradient values, but I found that the following two values differ, even though I thought I had only decomposed the first one manually using the chain rule:

    grad_0 = torch.autograd.grad(
        0.5 * F.mse_loss(noise_pred, noise, reduction="sum"), latents, retain_graph=True
    )[0]
    grad_1 = (
        torch.autograd.grad(
            0.5 * F.mse_loss(noise_pred, noise, reduction="sum"), noise_pred, retain_graph=True
        )[0]
        * torch.autograd.grad(
            noise_pred, latents, grad_outputs=torch.ones_like(noise_pred), retain_graph=True
        )[0]
    )
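
Writing both expressions out with the chain rule may help locate the difference. A sketch, with notation introduced here for illustration only: tensors are treated as flattened vectors, x stands for `latents`, y for `noise_pred` as a function of x, ε for `noise`, and J = ∂y/∂x for the Jacobian of everything between them (the noising step plus the U-Net):

```latex
% x = latents, y = noise_pred(x), \epsilon = noise, J = \partial y / \partial x
L(x) = \tfrac{1}{2}\,\lVert y(x) - \epsilon \rVert_2^2,
\qquad \frac{\partial L}{\partial y} = y - \epsilon

% grad_0: the full chain rule is a vector-Jacobian product
\texttt{grad\_0} = \nabla_x L = J^\top (y - \epsilon),
\qquad (\texttt{grad\_0})_j = \sum_i (y_i - \epsilon_i)\, J_{ij}

% grad_1: ones as grad_outputs gives J^T 1, followed by an elementwise product
\texttt{grad\_1} = (y - \epsilon) \odot (J^\top \mathbf{1}),
\qquad (\texttt{grad\_1})_j = (y_j - \epsilon_j) \sum_i J_{ij}
```

The two coincide for every residual y − ε only when J is diagonal, that is, when each entry of `noise_pred` depends only on the matching entry of `latents`. Because the convolutions and attention layers of a U-Net mix entries across positions and channels, that condition would not be expected to hold here.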

The complete code looks like this:

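    import torch
    import torch.nn.functional as F
    from jaxtyping import Float, Int
    from torch import Tensor

    # threestudio helpers used below
    from threestudio.models.prompt_processors.base import PromptProcessorOutput
    from threestudio.utils.ops import perpendicular_component


    def compute_grad_sds(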
        self,
        latents: Float[Tensor, "B 4 64 64"],
        t: Int[Tensor, "B"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
    ):
        batch_size = elevation.shape[0]

        if prompt_utils.use_perp_neg:  # False
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 4),
                    encoder_hidden_states=text_embeddings,
                )  # (4B, 4, 64, 64)

            noise_pred_text = noise_pred[:batch_size]
            noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
            noise_pred_neg = noise_pred[batch_size * 2 :]

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            neg_guidance_weights = None
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )  
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)  # [2, 4, 64, 64]
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t] * 2),  # shape [2]
                encoder_hidden_states=text_embeddings,
            )  # SDS with U-Net Jacobian

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)  # conditional / unconditional noise predictions
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        if self.cfg.weighting_strategy == "sds":
            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        elif self.cfg.weighting_strategy == "uniform":
            w = 1
        elif self.cfg.weighting_strategy == "fantasia3d":
            w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
        else:
            raise ValueError(
                f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
            )

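        # grad_0: gradient of the loss w.r.t. latents, back-propagated through the U-Net
        grad_0 = torch.autograd.grad(
            0.5 * F.mse_loss(noise_pred, noise, reduction="sum"), latents, retain_graph=True
        )[0]
        # grad_1: elementwise product of d(loss)/d(noise_pred) and the gradient of
        # noise_pred.sum() w.r.t. latents
        grad_1 = (
            torch.autograd.grad(
                0.5 * F.mse_loss(noise_pred, noise, reduction="sum"),
                noise_pred,
                retain_graph=True,
            )[0]
            * torch.autograd.grad(
                noise_pred,
                latents,
                grad_outputs=torch.ones_like(noise_pred),
                retain_graph=True,
            )[0]
        )
        breakpoint()  # inspect grad_0 and grad_1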
        grad = w * (noise_pred.detach() - noise)

        guidance_eval_utils = {
            "use_perp_neg": prompt_utils.use_perp_neg,
            "neg_guidance_weights": neg_guidance_weights,
            "text_embeddings": text_embeddings,
            "t_orig": t,
            "latents_noisy": latents_noisy,
            "noise_pred": noise_pred,
        }

        return grad, guidance_eval_utils

I am very confused. The discrepancy is not just a matter of rounding errors. Can someone please help me?
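
The mismatch can be reproduced without Stable Diffusion. A minimal sketch, assuming a hypothetical stand-in for the U-Net: a fixed random matrix `A`, so that `noise_pred = A @ latents` and the Jacobian is simply `A`. It computes `grad_0` and `grad_1` the same way as in the issue, plus a two-step chain rule that passes d(loss)/d(noise_pred) as `grad_outputs`, which PyTorch uses as the vector in a vector-Jacobian product:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n = 8
# Hypothetical stand-in for the U-Net: a fixed linear map that mixes entries,
# so its Jacobian (A itself) has non-zero off-diagonal entries.
A = torch.randn(n, n)
latents = torch.randn(n, requires_grad=True)
noise = torch.randn(n)  # plays the role of the sampled noise (no grad)

noise_pred = A @ latents
loss = 0.5 * F.mse_loss(noise_pred, noise, reduction="sum")

# grad_0: full gradient of the loss w.r.t. latents, i.e. A^T (noise_pred - noise)
grad_0 = torch.autograd.grad(loss, latents, retain_graph=True)[0]

# grad_1 as in the issue: (dL/d noise_pred) * (A^T 1), an elementwise product
dL_dy = torch.autograd.grad(loss, noise_pred, retain_graph=True)[0]
jt_ones = torch.autograd.grad(
    noise_pred, latents, grad_outputs=torch.ones_like(noise_pred), retain_graph=True
)[0]
grad_1 = dL_dy * jt_ones

# Two-step chain rule done as a vector-Jacobian product: dL/dy goes in grad_outputs
grad_vjp = torch.autograd.grad(noise_pred, latents, grad_outputs=dL_dy)[0]

print(torch.allclose(grad_0, grad_vjp))  # True
print(torch.allclose(grad_0, grad_1))    # False for a non-diagonal A
```

Applied to the issue's code, the analogous change would be to pass the first factor of `grad_1` as `grad_outputs` to the second `torch.autograd.grad` call instead of `torch.ones_like(noise_pred)`; the result would then be expected to match `grad_0` up to floating-point error.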

OrangeSodahub commented 3 months ago

I don't understand what you think grad_0 and grad_1 represent; they are not the same thing here.
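
To illustrate why the two are not the same, a small sketch (with hypothetical toy functions standing in for the U-Net) builds the Jacobian explicitly with `torch.autograd.functional.jacobian` and compares the chain-rule gradient J^T r with the `grad_1` construction r ⊙ (J^T 1), where r = y − ε. For an elementwise map, whose Jacobian is diagonal, the two agree; once an output depends on several inputs, they do not:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
n = 6
x = torch.randn(n)
eps = torch.randn(n)


def elementwise(v):
    # Diagonal Jacobian: output i depends only on input i
    return torch.tanh(v)


def mixing(v):
    # Hypothetical mixing map: every output depends on every input (dense Jacobian)
    return torch.tanh(v) + v.sum()


def compare(f):
    J = jacobian(f, x)  # shape (n, n), J[i, j] = d f_i / d x_j
    r = f(x) - eps  # dL/dy for L = 0.5 * ||f(x) - eps||^2
    true_grad = J.T @ r  # chain rule: J^T r
    issue_grad = r * (J.T @ torch.ones(n))  # the grad_1 construction
    return torch.allclose(true_grad, issue_grad)


print(compare(elementwise))  # True
print(compare(mixing))       # False
```

Materializing J like this is only practical for tiny inputs: for latents of shape B×4×64×64 the Jacobian would have (B·16384)² entries, which is why autograd computes vector-Jacobian products instead of forming J.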