rfeinman / pytorch-minimize

Newton and Quasi-Newton optimization with PyTorch
https://pytorch-minimize.readthedocs.io
MIT License

Help with RuntimeError #31

Closed · bfialkoff closed this issue 4 months ago

bfialkoff commented 6 months ago

I am using this project in conjunction with pytorch3d. I initially used scipy's least_squares and it worked fairly well. I decided to use torchmin.least_squares to take advantage of torch's autograd, which should work with pytorch3d's differentiable renderer.

Basically, I have a function that takes a weight vector and returns a 3D mesh that depends linearly on it. The idea is to find a weight vector that produces a mesh whose rendered depth map matches a ground-truth depth map. Basically: $\underset{\text{weights}}{\operatorname{argmin}} \; \lVert \text{gt\_depth\_map} - \text{render\_depth}(\text{get\_mesh}(\text{weights})) \rVert$

Here is my code:

import torch
from pytorch3d.structures import Meshes
from torchmin import least_squares as torch_least_squares

def get_mesh(self, weights: torch.Tensor, scale=1.):
    weights = torch.atleast_2d(weights)
    num_meshes = weights.shape[0]
    shape_space = self.vector_space_basis.type(weights.dtype)
    assert self.vector_space_basis.shape[0] == weights.shape[1]

    # project the weight vector onto the shape basis to get per-vertex displacements
    displacement_vertices = torch.mm(weights, shape_space).view(weights.shape[0], -1, 3)
    new_vertices = self.average_mesh_torch.verts_packed().unsqueeze(0) + displacement_vertices
    out_mesh = Meshes(verts=new_vertices * scale,
                      faces=torch.cat(num_meshes * [self.average_mesh_torch.faces_packed().unsqueeze(0)]))
    return out_mesh

def torch_cost(weights, gt_depth_map, renderer, scale):
    pred_mesh = get_mesh(weights, scale=scale)
    pred_depth_map = renderer(pred_mesh)
    cost = torch.square(pred_depth_map - gt_depth_map).view(-1)
    return cost

x = torch.zeros(n_components, requires_grad=True)
res = torch_least_squares(lambda x: torch_cost(x, gt_depth_map, renderer_depth, dummy_scale), x)

This fails with the following stacktrace:

res = torch_least_squares(lambda x: torch_cost(x, gt_depth_map, renderer_depth, dummy_scale), x)
File "<paths>/torchmin/lstsq/least_squares.py", line 271, in least_squares
result = trf(fun_wrapped, x0, f0, lb, ub, ftol, xtol, gtol,
File "<paths>/torchmin/lstsq/trf.py", line 25, in trf
return trf_no_bounds(
File "<paths>/torchmin/lstsq/trf.py", line 108, in trf_no_bounds
gn_h = lsmr(J_h, f, damp=damp_full, **tr_options)[0]
File "<paths>/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "<paths>/torchmin/lstsq/lsmr.py", line 178, in lsmr
u.mul_(-alpha).add_(A.matvec(v))
File "<paths>/torchmin/lstsq/linear_operator.py", line 46, in matvec
return self._matvec(x)
File "<paths>/torchmin/lstsq/common.py", line 162, in <lambda>
matvec=lambda x: J.matvec(x.view(-1) * d),
File "<paths>/torchmin/lstsq/linear_operator.py", line 46, in matvec
return self._matvec(x)
File "<paths>/torchmin/lstsq/linear_operator.py", line 28, in jvp
jvp, = autograd.grad(gx, gf, v, retain_graph=True)
File "<paths>/torch/autograd/__init__.py", line 394, in grad
result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

This error does not happen when I use the minimize function, but then again, minimize fails to actually minimize anything. For comparison, here is the equivalent scipy code:

import numpy as np
from scipy.optimize import least_squares

def cost(weights, gt_depth_map, renderer, scale):
    pred_mesh = get_mesh(weights, scale=scale)
    pred_depth_map = renderer(pred_mesh).numpy()[0, ..., 0]
    return np.square(pred_depth_map - gt_depth_map).reshape(-1)

res = least_squares(cost, np.zeros(n_components), args=(gt_depth_map, renderer_depth, dummy_scale))
rfeinman commented 6 months ago

Hi @bfialkoff

It sounds like this is a problem with your pytorch objective function. Have you checked the output of your function to make sure it has a .grad_fn and that the gradient correctly flows back to your input variable? It seems like the input variable is not being used in the graph. Whereas least_squares will raise an error if there are unused variables, minimize allows unused variables but leaves it up to the user to debug issues in the gradient graph.
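For concreteness, a minimal sketch of that check (assuming the same torch_cost, gt_depth_map, renderer_depth, dummy_scale, and n_components names as in the snippet above; allow_unused just makes the unused case visible instead of raising):

x = torch.zeros(n_components, requires_grad=True)
out = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
print(out.grad_fn)  # should not be None
# if x was never used in the graph, g comes back as None instead of a tensor
g, = torch.autograd.grad(out.sum(), x, allow_unused=True)
print(g)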

As a sanity check of your pytorch function I would first try optimizing the function using a standard pytorch optimizer (e.g. torch.optim.Adam). If this works without error then we would need to look into it a bit more.
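A minimal sketch of that sanity check, again reusing the names from the snippet above (the learning rate and step count are arbitrary guesses):

x = torch.zeros(n_components, requires_grad=True)
opt = torch.optim.Adam([x], lr=1e-2)
for step in range(200):
    opt.zero_grad()
    residual = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
    loss = residual.pow(2).sum()  # sum-of-squares objective
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(step, loss.item())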

PS - the function passed to least_squares is expected to return the residual vector, not the squared residuals. I would update this if I were you.
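A sketch of the corrected objective under that convention (least_squares squares and sums the residuals internally, so the function should return the raw differences; the same get_mesh/renderer interface as above is assumed):

def torch_cost(weights, gt_depth_map, renderer, scale):
    pred_mesh = get_mesh(weights, scale=scale)
    pred_depth_map = renderer(pred_mesh)
    # raw residual vector, not squared
    return (pred_depth_map - gt_depth_map).view(-1)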

bfialkoff commented 6 months ago

Thanks for the detailed reply. I first tried with Adam (or rather SGD) and it worked, in the sense that it didn't crash, but it totally fails to converge, which I think may be another indication that there is an issue with gradient flow. As far as I can tell, though, everything seems OK. I stepped through the code and checked requires_grad and grad_fn at every step in the chain; see below. Does anything look off?

from torchmin import least_squares as torch_least_squares

def get_mesh(self, weights: torch.Tensor, scale=1.):
    # upon entry weights has grad_fn=<IndexPutBackward0>
    weights = torch.atleast_2d(weights)
    # now weights has grad_fn=<UnsqueezeBackward0>

    num_meshes = weights.shape[0]
    shape_space = self.vector_space_basis.type(weights.dtype)
    assert self.vector_space_basis.shape[0] == weights.shape[1]

    displacement_vertices = torch.mm(weights, shape_space).view(weights.shape[0], -1, 3)
    # displacement_vertices has grad_fn=<ViewBackward0>

    new_vertices = self.average_mesh_torch.verts_packed().unsqueeze(0) + displacement_vertices
    # new_vertices has grad_fn=<AddBackward0 at 0x2cb9e8e80>

    out_mesh = Meshes(verts=new_vertices * scale, faces=torch.cat(num_meshes * [self.average_mesh_torch.faces_packed().unsqueeze(0)]))
    # out_mesh stores new_vertices * scale in its packed verts; verts_packed() has grad_fn=<CatBackward0>
    return out_mesh

def torch_cost(weights, gt_depth_map, renderer, scale):
    # weights on entry has grad_fn=<IndexPutBackward0>
    pred_mesh = get_mesh(weights, scale=scale)
    # pred_mesh.verts_packed() as above has grad_fn=<CatBackward0 at 0x2cba31a90>

    pred_depth_map = renderer(pred_mesh)
    # pred_depth_map has grad_fn=<UnsqueezeBackward0>

    cost = (pred_depth_map - gt_depth_map).view(-1)
    # cost has grad_fn=<ViewBackward0>
    return cost

x = torch.zeros(n_components, requires_grad=True)
res = torch_least_squares(lambda x: torch_cost(x, gt_depth_map, renderer_depth, dummy_scale), x)

For extra verbosity, I also have the computation graph (which is way more complex than I would have expected); see the attached output_graph image.

rfeinman commented 6 months ago

Hi - I don't know much about pytorch3d or the ops involved in mesh generation/rendering so I can't help there. But as a sanity check have you looked at x.grad after doing a forward + backward pass of the objective? See code snippet below. You should run this and make sure that x.grad is not zeros or nans, etc.

x = torch.zeros(n_components, requires_grad=True)
residual = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
residual.pow(2).sum().backward()
print(x.grad)
bfialkoff commented 6 months ago

Thanks for the engagement, I appreciate it :) I had of course already tried and checked this:

r = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
r.pow(2).sum().backward()
print(x.grad)
<<< tensor([8.6832e+13, 2.4886e+13, 1.9570e+13, 3.4736e+13, 2.6175e+13])

But if I run it twice...

rn = torch_cost(x+x.grad, gt_depth_map, renderer_depth, dummy_scale)
rn.pow(2).sum().backward()
print(x.grad)
<<< tensor([nan, nan, nan, nan, nan])

I guess the next step is to figure out what happens to the gradients; I didn't think to check them twice. Any idea why that happens on the 2nd pass? I'll have to check further, but I'm pretty sure the scipy least_squares would also fail if things go NaN. It's not clear to me whether this is an issue in the implementation or just something wrong with my usage/cost function. Thank you :)
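One way to localize where the NaNs first appear, as a sketch reusing the same names as above: PyTorch's anomaly detection makes the backward pass report the forward op associated with the NaN-producing gradient, at the cost of a slower run.

with torch.autograd.detect_anomaly():
    rn = torch_cost(x + x.grad, gt_depth_map, renderer_depth, dummy_scale)
    rn.pow(2).sum().backward()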

rfeinman commented 6 months ago

Well, as you can see, the first derivative is massive, so I am guessing that the second derivative is so large it is beyond the range of fp32 :)
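A sketch of one way to test that hypothesis (assuming the renderer accepts float64 inputs and reusing the same names as above): repeat the two-pass check in double precision and take a much smaller step, since with gradients around 1e13 a raw step of x + x.grad jumps far outside any plausible weight range.

x = torch.zeros(n_components, dtype=torch.float64, requires_grad=True)
r = torch_cost(x, gt_depth_map.double(), renderer_depth, dummy_scale)
r.pow(2).sum().backward()
print(x.grad)  # large but hopefully finite in float64

# take a tiny step in the gradient direction before re-evaluating
x2 = (x - 1e-14 * x.grad).detach().requires_grad_(True)
r2 = torch_cost(x2, gt_depth_map.double(), renderer_depth, dummy_scale)
r2.pow(2).sum().backward()
print(x2.grad)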