Closed bfialkoff closed 4 months ago
Hi @bfialkoff
It sounds like this is a problem with your pytorch objective function. Have you checked the output of your function to make sure it has a .grad_fn and that the gradient correctly flows back to your input variable? It seems like the input variable is not being used in the graph. Whereas least_squares will raise an error if there are unused variables, minimize allows unused variables but leaves it up to the user to debug issues in the gradient graph.
As a sanity check of your pytorch function, I would first try optimizing it with a standard pytorch optimizer (e.g. torch.optim.Adam). If this works without error then we would need to look into it a bit more.
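A minimal sketch of that sanity check, assuming a toy linear residual function in place of the real mesh/render objective (`toy_residuals`, `A`, `b` and the learning rate are illustrative choices, not part of the original code):

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the real objective: residuals of a small linear model.
A = torch.randn(50, 5)
x_true = torch.randn(5)
b = A @ x_true

def toy_residuals(x):
    return A @ x - b

x = torch.zeros(5, requires_grad=True)
optimizer = torch.optim.Adam([x], lr=0.1)

losses = []
for step in range(200):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = toy_residuals(x).pow(2).sum()
    loss.backward()
    # Stop early if the gradient is already broken instead of silently diverging.
    assert torch.isfinite(x.grad).all(), f"non-finite gradient at step {step}"
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Swapping `toy_residuals` for the real cost function should show whether the loss decreases at all and at which step the gradient stops being finite.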
PS - the function passed to least_squares is expected to return the residual vector, not the squared residual vector. I would update this if I were you.
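To illustrate why the distinction matters, here is a sketch using a hand-written Gauss-Newton step on a toy linear problem. This is not how torchmin implements least_squares; `gauss_newton_step`, `A`, `b` and `x_true` are hypothetical names made up for the example.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# Toy linear problem with a known solution.
A = torch.randn(20, 3, dtype=torch.float64)
x_true = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
b = A @ x_true

def residuals(x):
    # Residual vector r(x); the solver minimizes sum(r(x)**2) itself.
    return A @ x - b

def squared_residuals(x):
    # What NOT to pass: the residuals are already squared here.
    return residuals(x) ** 2

def gauss_newton_step(fun, x):
    # Solve the linearized problem min_s ||J s + r|| and take the step.
    r = fun(x)
    J = jacobian(fun, x)
    step = torch.linalg.lstsq(J, -r.unsqueeze(1)).solution.squeeze(1)
    return x + step

x0 = torch.zeros(3, dtype=torch.float64)
x1 = gauss_newton_step(residuals, x0)             # linear residuals: one step suffices
x1_sq = gauss_newton_step(squared_residuals, x0)  # squared residuals: it does not

print("residual vector: ", x1)
print("squared residual:", x1_sq)
```

With the plain residual vector the linearization is exact for this problem, so a single step recovers `x_true`. Squaring the residuals changes the Jacobian the solver sees, and the same step no longer lands on the solution.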
Thanks for the detailed reply.
I first tried with Adam (or rather SGD) and it worked, in the sense that it didn't crash, but it totally fails to converge. I think that may be another indication of an issue with gradient flow, but as far as I can tell everything seems OK. I stepped through the code and checked requires_grad and grad_fn for every step in the chain. See below. Does anything look off?
import torch
from pytorch3d.structures import Meshes
from torchmin import least_squares as torch_least_squares

def get_mesh(self, weights: torch.Tensor, scale=1.):
    # upon entry weights has grad_fn=<IndexPutBackward0>
    weights = torch.atleast_2d(weights)
    # now weights has grad_fn=<UnsqueezeBackward0>
    num_meshes = weights.shape[0]
    shape_space = self.vector_space_basis.type(weights.dtype)
    assert self.vector_space_basis.shape[0] == weights.shape[1]
    # use the dtype-matched basis computed above
    displacement_vertices = torch.mm(weights, shape_space).view(weights.shape[0], -1, 3)
    # displacement_vertices has grad_fn=<ViewBackward0>
    new_vertices = self.average_mesh_torch.verts_packed().unsqueeze(0) + displacement_vertices
    # new_vertices has grad_fn=<AddBackward0 at 0x2cb9e8e80>
    out_mesh = Meshes(
        verts=new_vertices * scale,
        faces=torch.cat(num_meshes * [self.average_mesh_torch.faces_packed().unsqueeze(0)]),
    )
    # out_mesh stores new_vertices * scale in its verts_packed(), which has grad_fn=<CatBackward0>
    return out_mesh

def torch_cost(weights, gt_depth_map, renderer, scale):
    # weights on entry has grad_fn=<IndexPutBackward0>
    pred_mesh = get_mesh_torch(weights, scale=scale, use_reduced_space=True)
    # pred_mesh.verts_packed() as above has grad_fn=<CatBackward0 at 0x2cba31a90>
    pred_depth_map = renderer(pred_mesh)
    # pred_depth_map has grad_fn=<UnsqueezeBackward0>
    cost = (pred_depth_map - gt_depth_map).view(-1)
    # cost has grad_fn=<ViewBackward0>
    return cost

x = torch.zeros(n_components, requires_grad=True)
res = torch_least_squares(lambda x: torch_cost(x, gt_depth_map, renderer_depth, dummy_scale), x)
For extra verbosity, I also have the computation graph (which is way more complex than I would have expected).
Hi - I don't know much about pytorch3d or the ops involved in mesh generation/rendering so I can't help there. But as a sanity check, have you looked at x.grad after doing a forward + backward pass of the objective? See the code snippet below. You should run this and make sure that x.grad is not zeros or nans, etc.
x = torch.zeros(n_components, requires_grad=True)
residual = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
residual.pow(2).sum().backward()
print(x.grad)
Thanks for the engagement, I appreciate it :) I had of course already tried and checked this:
r = torch_cost(x, gt_depth_map, renderer_depth, dummy_scale)
r.pow(2).sum().backward()
print(x.grad)
<<< tensor([8.6832e+13, 2.4886e+13, 1.9570e+13, 3.4736e+13, 2.6175e+13])
But if I run it twice...
rn = torch_cost(x+x.grad, gt_depth_map, renderer_depth, dummy_scale)
rn.pow(2).sum().backward()
print(x.grad)
<<< tensor([nan, nan, nan, nan, nan])
I guess the next step is to figure out what happens to the gradients. I didn't think to check them twice. Any idea why that happens on the 2nd pass? I'll have to check further, but I'm pretty sure the scipy least_squares would also fail if things go nan. It's not clear to me whether this is an issue in the implementation or just something wrong with my usage/cost function. Thank you :)
Well as you can see the first derivative is massive so I am guessing that the second derivative is so large it is beyond the range of fp32 :)
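A sketch of one way to inspect each pass in isolation. Two things in the snippet above may also be worth ruling out: x.grad accumulates across backward() calls unless it is reset, and x + x.grad with gradients of order 1e13 evaluates the cost extremely far from x. `grad_report` and the toy residual function are hypothetical helpers written for this example.

```python
import torch

def grad_report(residual_fn, x):
    """Run one forward + backward pass from a clean gradient and summarize it."""
    x.grad = None  # reset so the result is not mixed with an earlier backward pass
    loss = residual_fn(x).pow(2).sum()
    loss.backward()
    g = x.grad.detach().clone()
    return {
        "loss": loss.item(),
        "grad": g,
        "finite": bool(torch.isfinite(g).all()),
        "nonzero": bool((g != 0).any()),
    }

# Toy residual function standing in for the real cost.
target = torch.tensor([1.0, 2.0, 3.0])
toy_fn = lambda x: x - target

x = torch.zeros(3, requires_grad=True)
first = grad_report(toy_fn, x)
second = grad_report(toy_fn, x)  # same point, so the gradient should be identical

print(first)
print(second)
```

Calling the helper on the real cost at x, and then at a point a small, controlled distance away (rather than x + x.grad), should show whether the nan comes from the cost function itself or from where it is being evaluated.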
I am using this project in conjunction with pytorch3d. I initially used scipy's least_squares and it worked fairly well. I decided to use torchmin.least_squares to take advantage of torch's autograd, which should work with pytorch3d's differentiable renderer.
I have a function that returns a 3D mesh that depends linearly on a weight vector. The idea is to find the weight vector that produces a mesh whose depth map matches a ground-truth depth map: $\underset{weights}{argmin} ||gt \textunderscore depth \textunderscore map - render \textunderscore depth(get \textunderscore mesh(weights))||$
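As a rough illustration of that setup (not the actual pytorch3d code), the linear mesh model and residual can be sketched with plain tensors. Everything here is a hypothetical stand-in: `mean_verts`, `basis`, and especially `toy_render_depth`, which just reads off the z coordinate instead of rendering anything.

```python
import torch

torch.manual_seed(0)
n_components, n_verts = 5, 100

# Hypothetical stand-ins for the average mesh and the linear basis.
mean_verts = torch.randn(n_verts, 3)
basis = torch.randn(n_components, n_verts * 3)

def toy_get_verts(weights):
    # Vertices depend linearly on the weights: mean + weights @ basis.
    return mean_verts + (weights @ basis).view(n_verts, 3)

def toy_render_depth(verts):
    # Crude placeholder for a differentiable renderer: per-vertex z coordinate.
    return verts[:, 2]

true_weights = torch.randn(n_components)
gt_depth = toy_render_depth(toy_get_verts(true_weights))

def toy_residuals(weights):
    # Residual vector between predicted and ground-truth depth.
    return toy_render_depth(toy_get_verts(weights)) - gt_depth

weights = torch.zeros(n_components, requires_grad=True)
toy_residuals(weights).pow(2).sum().backward()
print(weights.grad)
```

In this toy version the gradient reaches `weights` directly, which gives a baseline to compare against when the real renderer is swapped in.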
Here is my code:
This fails with the following stacktrace:
This error does not happen when I use the minimize function, but then again minimize fails to actually minimize anything. For comparison, the equivalent scipy code: