rfeinman / pytorch-minimize

Newton and Quasi-Newton optimization with PyTorch
https://pytorch-minimize.readthedocs.io
MIT License

Analytic Gradient/Hessian #3

Closed: fzimmermann89 closed this issue 2 years ago

fzimmermann89 commented 2 years ago

Hi,

Just a question: do you know if there is a simple way to provide a callable for the gradient and/or the Hessian (for example, when analytical expressions are known)?

Thank you in advance! Felix

rfeinman commented 2 years ago

Hi @fzimmermann89 -

At the moment it would be a considerable refactor to provide this option. Furthermore, the option should not be necessary: PyTorch computes the exact analytic gradient on the backend, so you'll get the same result with the current package.

If you'd like to provide a callable for the gradient & Hessian, then I'd suggest using SciPy's optimize module, which has an extensive repertoire of tools for this purpose. The motivation for pytorch-minimize was to provide a "hands-free" alternative to these tools.
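
For reference, the SciPy route would look roughly like this (a toy quadratic, just to sketch the jac/hess keyword arguments of scipy.optimize.minimize):

import numpy as np
from scipy.optimize import minimize

def f(x):
    return 0.5 * x @ x            # toy quadratic objective

def grad(x):
    return x                      # analytic gradient

def hess(x):
    return np.eye(x.size)         # analytic Hessian

res = minimize(f, np.array([1., 8.]), jac=grad, hess=hess, method='trust-ncg')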

rfeinman commented 2 years ago

Hi @fzimmermann89,

I thought of an easy workaround for you if you're set on using your own gradient function rather than PyTorch's automatic gradient. Simply use the method below to create a custom objective for torchmin.minimize.

import torch

def make_func(fn, grad_fn):
    """Trick to make a custom autograd function that uses grad_fn for its backward pass"""
    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # save the input so backward can evaluate grad_fn(x)
            ctx.save_for_backward(x)
            return fn(x)

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            if grad_output.numel() == 1:
                # scalar-valued fn: scale the gradient by the incoming grad
                return grad_fn(x) * grad_output
            else:
                # vector-valued fn: grad_fn(x) is a matrix (e.g. a Hessian),
                # so apply a matrix-vector product
                return grad_fn(x) @ grad_output

    return Function.apply

Here is a little demo:

fn = make_func(my_fn, my_grad_fn)

result = torchmin.minimize(fn, torch.tensor([1., 8.]), method='bfgs')
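
For concreteness, my_fn and my_grad_fn could be something like this toy quadratic (purely illustrative, not part of the package):

def my_fn(x):
    return 0.5 * (x ** 2).sum()   # toy quadratic objective

def my_grad_fn(x):
    return x                      # its analytic gradient

With those definitions, result.x should come out close to [0., 0.].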

You can use a custom Hessian function by applying this pre-processing step in a nested fashion:

fn = make_func(my_fn, make_func(my_grad_fn, my_hess_fn))

For the inner call, the matrix-vector branch (grad_fn(x) @ grad_output) is used inside Function.backward() instead of the vector-scalar branch (grad_fn(x) * grad_output), since the inner grad_fn returns a Hessian matrix rather than a gradient vector.
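
Continuing the toy quadratic above (my_hess_fn is just an illustrative name), the second-order version would be:

def my_hess_fn(x):
    return torch.eye(x.numel())   # analytic Hessian of the toy quadratic

fn = make_func(my_fn, make_func(my_grad_fn, my_hess_fn))

result = torchmin.minimize(fn, torch.tensor([1., 8.]), method='newton-exact')

Pick a Newton-type method (e.g. 'newton-exact' or 'newton-cg') so that the Hessian is actually used.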