863689877 opened this issue 2 years ago
Hi @zachteed, thank you for your wonderful work! Maybe you could help when you are free~
Hi, have you understood the backward design of `CholeskySolver`? I have no idea how to write out the chain of equations for its backward pass. Could you give me some tips?
```python
import torch


class CholeskySolver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        try:
            U = torch.linalg.cholesky(H)      # lower-triangular factor, H = U U^T
            xs = torch.cholesky_solve(b, U)   # xs = H^{-1} b
            ctx.save_for_backward(U, xs)
            ctx.failed = False
        except Exception as e:
            print(e)
            ctx.failed = True
            xs = torch.zeros_like(b)
        return xs

    @staticmethod
    def backward(ctx, grad_x):
        if ctx.failed:
            return None, None                 # no gradient if the factorization failed
        U, xs = ctx.saved_tensors
        dz = torch.cholesky_solve(grad_x, U)  # dz = H^{-1} grad_x, the gradient w.r.t. b
        dH = -torch.matmul(xs, dz.transpose(-1, -2))  # gradient w.r.t. H, as -xs dz^T
        return dH, dz
```
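One way to write the chain of equations, as a sketch assuming H is symmetric positive definite (this is not the author's own explanation). Differentiating H x = b gives dH x + H dx = db, so dx = H^{-1}(db - dH x). With g = dL/dx and z = H^{-T} g = H^{-1} g, the loss differential is dL = g^T dx = z^T db - z^T dH x, and reading off the coefficients gives dL/db = z and dL/dH = -z x^T. The code in the question returns the transpose, -x z^T. If H enters the loss only through a symmetric construction (for example H = A A^T + nI below), only the symmetric part G + G^T of the gradient G = dL/dH reaches the upstream tensors, so both forms give the same result. The check below is a hypothetical sketch: `CholeskySolverSketch`, `build_spd`, `solve_via_sketch` and `solve_via_reference` are illustrative names, and it compares the hand-written backward against `torch.autograd.gradcheck` and against PyTorch's own autograd through `torch.linalg.solve`.

```python
import torch


class CholeskySolverSketch(torch.autograd.Function):
    """Hypothetical x = H^{-1} b for SPD H, with the textbook backward -z x^T."""

    @staticmethod
    def forward(ctx, H, b):
        L = torch.linalg.cholesky(H)       # lower-triangular factor: H = L L^T
        x = torch.cholesky_solve(b, L)     # x = H^{-1} b
        ctx.save_for_backward(L, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        L, x = ctx.saved_tensors
        z = torch.cholesky_solve(grad_x, L)  # dL/db = H^{-1} g
        dH = -z @ x.transpose(-1, -2)        # dL/dH = -H^{-1} g x^T
        return dH, z


def build_spd(A):
    # Hypothetical helper: symmetric positive definite H from an unconstrained A,
    # so every perturbation of A keeps H symmetric.
    n = A.shape[-1]
    return A @ A.transpose(-1, -2) + n * torch.eye(n, dtype=A.dtype)


def solve_via_sketch(A, b):
    return CholeskySolverSketch.apply(build_spd(A), b)


def solve_via_reference(A, b):
    return torch.linalg.solve(build_spd(A), b)


torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)

# Finite-difference check of the hand-written backward (gradcheck wants float64).
ok = torch.autograd.gradcheck(solve_via_sketch, (A, b))

# Same scalar loss, gradients from PyTorch's autograd through torch.linalg.solve.
g_sketch = torch.autograd.grad(solve_via_sketch(A, b).square().sum(), (A, b))
g_ref = torch.autograd.grad(solve_via_reference(A, b).square().sum(), (A, b))
matches = all(torch.allclose(s, r) for s, r in zip(g_sketch, g_ref))
print(ok, matches)
```

Building H from A keeps the test honest: `torch.linalg.cholesky` reads only the lower triangle of H, so perturbing the upper triangle of a raw H independently would make a finite-difference check disagree with any full-matrix gradient.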
Hi, I noticed that you use a differentiable bundle adjustment (BA) implemented in PyTorch during training, but a non-differentiable BA implemented in CUDA during testing. What is the reason for this? Is the PyTorch backward pass faster than a CUDA rewrite would be, while its forward pass is slower than the CUDA version?