teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

possible second derivative issue? #69

Closed: SingletC closed 1 year ago

SingletC commented 1 year ago

Hi, thank you for sharing this PyTorch version of soft sort. I have some applications that require computing the Hessian matrix, and I run into a problem when doing so. Here is a minimal example that reproduces the error.

import torch

def compute_hessian(inputs, outputs):
    inputs.requires_grad_(True)
    # First-order gradients; create_graph=True keeps the graph so we
    # can differentiate a second time.
    grads = torch.autograd.grad(outputs, inputs, create_graph=True)[0]
    hessian_rows = []
    for grad_elem in grads:
        # Differentiate each gradient component again to get one Hessian row.
        hessian_row = torch.autograd.grad(grad_elem, inputs, retain_graph=True)[0]
        hessian_rows.append(hessian_row)
    hessian = torch.stack(hessian_rows)
    return hessian

def target_function(inputs):
    sort, _ = torch.sort(inputs)  # torch.sort returns (values, indices)
    return sort[0] * sort[1] + sort[0] ** 2 + sort[1] ** 2

inputs = torch.tensor([1.0, 2.0], requires_grad=True)
outputs = target_function(inputs)
hessian = compute_hessian(inputs, outputs)

With torch.sort I get

tensor([[0., 1.],
        [1., 0.]])

Then I try torchsort's soft_sort:

import torchsort

def sort_function(inputs):
    # soft_sort expects a 2D (batch, n) input, so reshape and take the first row
    sort = torchsort.soft_sort(inputs.reshape(1, -1))[0]
    return sort[0] * sort[1] + sort[0] ** 2 + sort[1] ** 2

outputs_ = sort_function(inputs)
hessian_ = compute_hessian(inputs, outputs_)

and I get

  File "<ipython-input-32-eb691f20f64c>", line 27, in <module>
    hessian = compute_hessian(inputs, outputs_)
  File "<ipython-input-32-eb691f20f64c>", line 9, in compute_hessian
    hessian_row = torch.autograd.grad(grad_elem, inputs, retain_graph=True)[0]
  File "soft/miniconda3/envs/madgp/lib/python3.9/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Here are my package versions:

pip freeze |grep torch
torch==2.0.1+cpu
torchaudio==2.0.2+cpu
torchsort==0.1.9
torchvision==0.15.2+cpu

I would appreciate any thoughts on whether this is an issue with torchsort or whether I did something wrong.

teddykoker commented 1 year ago

Hi @SingletC, second derivatives are currently not supported by this library. Since it is written as a C++/CUDA extension, PyTorch does not support "double backwards" through it (see "when backward is not tracked" in PyTorch's documentation on double backward with custom functions). One could write a custom backward function for the existing backward function to enable higher-order derivatives, but this could be quite difficult, and I would not have the bandwidth to work on it. I believe you would have the same issue with the original codebase, as its operations are written in plain NumPy, but there may be some other solutions from recent years if you look around.
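
To illustrate the "custom backward" route mentioned above: double backwards works automatically for a torch.autograd.Function whose backward is itself built from plain, tracked tensor operations. Below is a minimal toy sketch of that pattern (SquareTracked is a hypothetical example function, not torchsort's actual kernels); an opaque C++/CUDA backward breaks exactly this recording step.

import torch

class SquareTracked(torch.autograd.Function):
    # Toy example: y = x ** 2 with a hand-written backward.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Plain tensor ops: with create_graph=True autograd records this
        # computation, so the gradient itself is differentiable.
        return 2 * x * grad_output

x = torch.tensor(3.0, requires_grad=True)
y = SquareTracked.apply(x)
(g,) = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 2x = 6
(h,) = torch.autograd.grad(g, x)                     # d2y/dx2 = 2

As a practical workaround while only first derivatives are supported, one could approximate the Hessian with central differences of first-order gradients, H[i] ~ (grad f(x + eps*e_i) - grad f(x - eps*e_i)) / (2*eps), which needs only single backward passes. A sketch under that assumption (grad_at and hessian_fd are hypothetical helpers, and eps may need tuning for your problem):

import torch
import torchsort

def sort_function(inputs):
    sort = torchsort.soft_sort(inputs.reshape(1, -1))[0]
    return sort[0] * sort[1] + sort[0] ** 2 + sort[1] ** 2

def grad_at(f, x):
    # First-order gradient of f, evaluated at a fresh leaf tensor.
    x = x.detach().clone().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x)
    return g

def hessian_fd(f, x, eps=1e-4):
    # Central-difference approximation: one Hessian row per coordinate.
    rows = []
    for i in range(x.numel()):
        e = torch.zeros_like(x)
        e[i] = eps
        rows.append((grad_at(f, x + e) - grad_at(f, x - e)) / (2 * eps))
    return torch.stack(rows)

hessian_approx = hessian_fd(sort_function, torch.tensor([1.0, 2.0]))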