ashawkey / stable-dreamfusion

Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion.
Apache License 2.0

custom backward does not account for gradient scaling #158

Open JunzheJosephZhu opened 1 year ago

JunzheJosephZhu commented 1 year ago

Description

Hi, it seems like the backward() function in SpecifyGradient does not take gradient scaling into account. I realized this because I tried to return torch.sum(gt_grad ** 2) and skip the differentiable U-Net part, and during the backward pass I saw that gt_grad was scaled by the square of the gradient scaling factor. In the original implementation, however, this is not the behavior: the gradient is not affected by the scaler at all.

Could someone explain why this is written this way?
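
For context, the scaler's effect can be seen in isolation: the grad_output handed to a custom autograd backward() already carries amp's loss scale. A minimal standalone sketch (my own, not repo code; it assumes a CUDA device, since GradScaler is effectively disabled on CPU):

import torch
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd

class PassThrough(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # grad_output = loss_scale * d(loss)/d(output), not the raw gradient
        print("grad_output sum:", grad_output.sum().item())
        return grad_output

x = torch.ones(4, device="cuda", requires_grad=True)
scaler = GradScaler(init_scale=1024.0)
loss = PassThrough.apply(x).sum()
scaler.scale(loss).backward()  # prints 4096.0 (= 4 * 1024), not 4.0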

Steps to Reproduce

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class SkipGrad(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, latents, error):
        print("forward sum", torch.sum(error ** 2))
        return latents, error

    @staticmethod
    @custom_bwd
    def backward(ctx, latent_grad, error_grad):
        print("backward sum", torch.sum(error_grad ** 2))
        error_grad = torch.nan_to_num(error_grad, posinf=1000, neginf=-1000)
        return error_grad, None

Expected Behavior

When you execute this, the printed backward sum will not equal the printed forward sum.
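
A hypothetical driver for the SkipGrad repro above (the tensor shapes and init_scale are placeholders I picked, not values from the repo; assumes a CUDA device):

# builds on the SkipGrad class defined above
latents = torch.randn(1, 4, 64, 64, device="cuda", requires_grad=True)
error = torch.randn_like(latents)

_, error_out = SkipGrad.apply(latents, error)      # prints "forward sum"
loss = torch.sum(error_out ** 2)

scaler = torch.cuda.amp.GradScaler(init_scale=1024.0)
scaler.scale(loss).backward()                      # prints "backward sum"
# backward sum = (2 * 1024)^2 * forward sum, since error_grad = 2 * scale * error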

Environment

Ubuntu 22

JunzheJosephZhu commented 1 year ago

In other words, with this repo's implementation the main SDS loss is not scaled by the gradient scaler, while the other losses are.

JunzheJosephZhu commented 1 year ago

Confirmed the issue by using the original SDS loss implementation and fixing the loss scale by setting self.scaler.update(1.0). Compared to before, when the scale was not fixed, the training result looks very different.
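
For reference, GradScaler.update() accepts an explicit new_scale, so passing 1.0 pins the scale and effectively disables dynamic loss scaling. A minimal sketch of a scaled training step with the scale pinned (the model and optimizer here are placeholders, not repo code; the actual loop is in nerf.utils.train_one_epoch):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 8, device="cuda")
loss = model(x).square().mean()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update(1.0)  # instead of scaler.update(): pin the loss scale to 1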

ashawkey commented 1 year ago

@JunzheJosephZhu This makes sense! Do you observe worse results compared to a correctly scaled implementation? I guess the following may fix it, but it needs some experiments for verification:

class SpecifyGradient(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad) 
        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) 

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        # grad_scale is the gradient flowing into the dummy output; under amp this is the current loss scale
        return grad_scale * gt_grad, None
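
A rough usage sketch of this version (illustrative shapes and init_scale, not repo values; assumes a CUDA device): the dummy output goes through scaler.scale(), so the grad_scale that reaches backward() equals the current loss scale, and gt_grad ends up scaled consistently with every other loss term.

latents = torch.randn(1, 4, 64, 64, device="cuda", requires_grad=True)
grad = torch.randn_like(latents)        # e.g. the SDS gradient w * (noise_pred - noise)

loss = SpecifyGradient.apply(latents, grad)
scaler = torch.cuda.amp.GradScaler(init_scale=1024.0)
scaler.scale(loss).backward()
# latents.grad == 1024.0 * grad, i.e. scaled by the same factor as the other losses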
Jainam2130 commented 1 year ago

Hey there, those are some awesome observations! So are the results better after fixing the loss scaling? Could you please provide a video of the results?

JunzheJosephZhu commented 1 year ago

[image: df_ep0100_0005_rgb.png] [image: df_ep0100_0005_rgb.png]

For comparison: the first image uses the erroneous gradient-scaling implementation, the second sets the gradient scale to 1, thereby disabling it. But this is from my own research, so I changed the rendering methods somewhat; I haven't run the comparison with the original stable-dreamfusion rendering, so I can't say for sure. You can try it out yourself by changing self.scaler.update() to self.scaler.update(1.0) in nerf.utils.train_one_epoch.


JunzheJosephZhu commented 1 year ago

[image: df_ep0100_0005_rgb] [image: df_ep0100_0005_rgb]

ashawkey commented 1 year ago

Thanks for the information! I have updated the code; maybe you are interested in trying it in your experiments too.