princeton-vl / DROID-SLAM

BSD 3-Clause "New" or "Revised" License
1.75k stars 295 forks

Different implementations of BA #70

Open 863689877 opened 2 years ago

863689877 commented 2 years ago

Hi, I noticed that you use a differentiable BA implemented in PyTorch during training, but a non-differentiable BA implemented in CUDA during testing. What is the reason for this? Is the backward pass of the PyTorch version faster than a CUDA rewrite would be, while its forward pass is slower than the CUDA version?

863689877 commented 2 years ago

Hi, @zachteed, thank you for your wonderful work! Maybe you can help when you are free~

GuoPingPan commented 1 year ago

Hi, have you understood the backward design of CholeskySolver? I have no idea how to derive the chain of equations for the backward pass of CholeskySolver. Can you give me some tips?

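import torch

class CholeskySolver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        try:
            # torch.linalg.cholesky returns the lower-triangular factor (H = U U^T),
            # which matches the default upper=False of torch.cholesky_solve
            U = torch.linalg.cholesky(H)
            xs = torch.cholesky_solve(b, U)  # xs = H^{-1} b
            ctx.save_for_backward(U, xs)
            ctx.failed = False
        except Exception as e:
            print(e)
            ctx.failed = True
            xs = torch.zeros_like(b)

        return xs

    @staticmethod
    def backward(ctx, grad_x):
        # forward returned zeros on failure, so no gradient is propagated
        if ctx.failed:
            return None, None

        U, xs = ctx.saved_tensors
        # gradient w.r.t. b: dz = H^{-1} grad_x (H is symmetric)
        dz = torch.cholesky_solve(grad_x, U)
        # gradient w.r.t. H: negated outer product of xs and dz
        dH = -torch.matmul(xs, dz.transpose(-1,-2))

        return dH, dz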
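
One way to see where these two lines of backward come from (a sketch, not a statement from the authors): with x = H^{-1} b, the differential is dx = H^{-1} db - H^{-1} dH x. Pairing it with the upstream gradient g = grad_x gives grad_b = H^{-T} g and grad_H = -(H^{-T} g) x^T. Since H is symmetric, H^{-T} g = H^{-1} g, which is dz in the code. The snippet above returns -xs dz^T, which is the transpose of -dz xs^T; the two agree once symmetrized, so they should give the same result whenever H is built symmetrically upstream (that is my assumption about how H is formed). The check below compares the closed form against autograd through torch.linalg.solve. The helper name cholesky_solve_with_manual_grad and the test sizes are made up for illustration.

```python
import torch

def cholesky_solve_with_manual_grad(H, b, grad_x):
    """Solve H x = b and return (x, dH, db) using the closed-form backward."""
    L = torch.linalg.cholesky(H)           # lower-triangular factor, H = L L^T
    x = torch.cholesky_solve(b, L)         # x = H^{-1} b
    dz = torch.cholesky_solve(grad_x, L)   # H^{-1} grad_x, gradient w.r.t. b
    dH = -torch.matmul(x, dz.transpose(-1, -2))  # same expression as in the snippet
    return x, dH, dz

torch.manual_seed(0)
dtype = torch.float64

# Random batch of symmetric positive definite matrices (hypothetical test data).
A = torch.randn(2, 4, 4, dtype=dtype)
H = A @ A.transpose(-1, -2) + 4.0 * torch.eye(4, dtype=dtype)
b = torch.randn(2, 4, 1, dtype=dtype)
grad_x = torch.randn(2, 4, 1, dtype=dtype)

# Reference gradients from autograd through a general linear solve.
H_ref = H.clone().requires_grad_(True)
b_ref = b.clone().requires_grad_(True)
x_ref = torch.linalg.solve(H_ref, b_ref)
dH_ref, db_ref = torch.autograd.grad(x_ref, (H_ref, b_ref), grad_outputs=grad_x)

x, dH, db = cholesky_solve_with_manual_grad(H, b, grad_x)

# The solution and the gradient w.r.t. b should match directly.
x_matches = torch.allclose(x, x_ref.detach())
db_matches = torch.allclose(db, db_ref)
# The gradient w.r.t. H should match up to a transpose, hence after symmetrization.
dH_matches_transposed = torch.allclose(dH, dH_ref.transpose(-1, -2))
dH_matches_symmetrized = torch.allclose(
    dH + dH.transpose(-1, -2), dH_ref + dH_ref.transpose(-1, -2)
)
print(x_matches, db_matches, dH_matches_transposed, dH_matches_symmetrized)
```

If all four flags come out true, the backward is just the standard adjoint of a linear solve, with the Cholesky factor reused so that no second factorization is needed.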