rfeinman / pytorch-minimize

Newton and Quasi-Newton optimization with PyTorch
https://pytorch-minimize.readthedocs.io
MIT License

Using _vmap in PyTorch to compute the Hessian-vector product (hvp) encounters a runtime error #33

Open bfialkoff opened 7 months ago

bfialkoff commented 7 months ago

Trying to use the minimize function with some of the available methods fails with a runtime error, but it succeeds with the rest. Presumably, the methods that succeed aren't computing Hessians.
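
For reference, a minimal call that exercises the Hessian path might look like the sketch below; the Rosenbrock objective and the 'newton-exact' method are illustrative assumptions, not details from the original report.

import torch
from torchmin import minimize

def rosen(x):
    # classic Rosenbrock test objective (assumed for illustration)
    return torch.sum(100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2)

x0 = torch.zeros(5)
# full-Hessian methods such as 'newton-exact' go through the _vmap hvp path
res = minimize(rosen, x0, method='newton-exact')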

I tracked the error to here.

I can't quite figure out what is really responsible for the error, but I suspect it's _vmap failing to batch properly, because my debugger indicates that something is wrong with the tensors that are yielded. If I look at the batched_inputs[0] variable in _vmap_internals._vmap and try to print it, view it, or add 1 to it, I get RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
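
As a hedged illustration of that failure mode (a standalone sketch, not the exact pytorch-minimize call stack): the legacy _vmap can only generate fallbacks for ops that return tensors, so an op that reduces a batched tensor to a Python bool, such as aten::is_nonzero (invoked by bool(t) or t.is_nonzero(), and easily hit by debugger inspection), fails outright.

import torch
from torch._vmap_internals import _vmap  # private PyTorch 1.x utility

def check(t):
    # is_nonzero returns a Python bool, so no batching rule exists and no
    # per-example fallback can be generated -> the RuntimeError quoted above
    return t.is_nonzero()

_vmap(check)(torch.randn(3))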

Computing the Hessian in a loop works, but it is hideous and slow:

# workaround: one reverse-mode hvp (H @ v) per basis vector in self._I,
# stacked along dim 0 to assemble the full Hessian
hvp_map = lambda V: torch.stack(
    [autograd.grad(grad, x, v, retain_graph=True)[0] for v in V], dim=0)
hess = hvp_map(self._I)
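
On PyTorch >= 1.11, one hedged alternative is to let autograd batch the products itself via is_grads_batched; it still relies on vmap internally, so it may run into the same batching-rule limits, but it avoids the Python loop (same grad, x, and identity matrix self._I as above).

# treat the leading dim of self._I as a batch of grad_outputs, producing
# one hvp per basis vector in a single call
hess = autograd.grad(grad, x, grad_outputs=self._I,
                     is_grads_batched=True, retain_graph=True)[0]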

Is this a real issue, or am I missing something?

rfeinman commented 6 months ago

Now that PyTorch 2 is out, we should remove all uses of the _vmap utility and replace them with forward-mode automatic differentiation. Jacobian & Hessian computation will be much more efficient.
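
A minimal sketch of that replacement using the torch.func API from PyTorch 2 (the quartic objective and 5-dimensional input are placeholders):

import torch
from torch.func import grad, jvp, vmap

def f(x):
    return (x ** 4).sum()  # placeholder objective

def hvp(x, v):
    # forward-over-reverse: the jvp of the gradient gives H @ v
    return jvp(grad(f), (x,), (v,))[1]

x = torch.randn(5)
# vmap over the identity assembles the full Hessian in one batched pass
hess = vmap(hvp, in_dims=(None, 0))(x, torch.eye(5))

torch.func.hessian(f) composes jacfwd over jacrev and is a one-call equivalent of the same forward-over-reverse strategy.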

I have added your ticket to the PyTorch 2 milestone. Help appreciated!