mmuckley / torchkbnufft

A high-level, easy-to-deploy non-uniform Fast Fourier Transform in PyTorch.
https://torchkbnufft.readthedocs.io/
MIT License

Higher order derivatives involving nufft? #33

Closed. ajlok3 closed this issue 2 years ago.

ajlok3 commented 3 years ago

Hi,

My problem is quite specific, but I'll try my best to describe it here:

I have two loss functions; let's call them loss1 and loss2. loss1 is computed from an input image x by applying the nufft operator to it and then computing loss1, i.e. f(x) := loss1(nufft(x)) (more operations are involved, but they are irrelevant here).

Subsequently, I call torch.autograd.grad on f to compute loss2, i.e. f'(x) is part of the backpropagation graph for loss2. However, when I try to execute my code, I sometimes get the following error message on the line where autograd.grad is called:

[screenshot of the error message]

Frankly, the setup is quite complicated, but it works fine when I remove the nufft operator from the function f. When the above error doesn't appear (usually after I restart the kernel in Jupyter), the code breaks only at the very end, when backward() is called on loss2, with an even more cryptic message:

[screenshot of the second error message]

Any suggestions as to what the underlying problem might be? Happy to provide more information if needed.
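For reference, a stripped-down sketch of the structure (not my actual code; the operator size, trajectory, and losses are placeholders):

import math

import torch
import torchkbnufft as tkbn

im_size = (32, 32)
nufft_op = tkbn.KbNufft(im_size=im_size)
ktraj = torch.rand(2, 100) * 2 * math.pi - math.pi  # toy 2D k-space trajectory

x = torch.randn(1, 1, *im_size, dtype=torch.complex64, requires_grad=True)

# f(x) := loss1(nufft(x))
loss1 = (torch.abs(nufft_op(x, ktraj)) ** 2).sum()

# f'(x) has to stay differentiable because it feeds into loss2,
# hence create_graph=True; this is the line that sometimes errors
(grad_x,) = torch.autograd.grad(loss1, x, create_graph=True)

loss2 = torch.abs(grad_x).sum()
loss2.backward()  # otherwise the failure shows up here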

Thanks and best, ajlok3

ajlok3 commented 3 years ago

So generally speaking, the issue would be: is it possible to build higher-order derivatives with nufft?

mmuckley commented 2 years ago

Hello @ajlok3, I haven't tested it with higher-order derivatives, but you might be able to take your old gradients and re-backprop with them by manually calling the grad function.

In terms of the error, we definitely do have an in-place operation. I'm guessing one of our cached tensors got its requires_grad turned on during the first backprop, and so we get the leaf variable error. Unfortunately, I don't have a ton of time to work on it at the moment. The backend is largely in interp.py if you have a proposed solution.
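To sketch what I mean by manually calling grad (plain PyTorch here, not tested against our NUFFT operators, so treat it as the general pattern rather than a verified fix):

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)

# first loss; create_graph=True keeps the graph attached to the gradient
# so that the gradient itself can be backpropagated through again
loss1 = (x ** 2).sum()
(grad_x,) = torch.autograd.grad(loss1, x, create_graph=True)

# second loss built from the old gradient; calling backward here is the
# "re-backprop" step
loss2 = grad_x.abs().sum()
loss2.backward()

print(x.grad)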

mmuckley commented 2 years ago

Hello @ajlok3, I wrote the following test script and encountered no errors. Could you let me know what the difference is between this and your setup?

import numpy as np
import torch
import torchkbnufft as tkbn


def create_input_plus_noise(shape, is_complex):
    # Deterministic ramp plus Gaussian noise, cast to the default dtype.
    x = np.arange(np.prod(shape)).reshape(shape)
    x = torch.tensor(x, dtype=torch.get_default_dtype())

    if is_complex:
        x = x + torch.randn(size=x.shape) + 1j * torch.randn(size=x.shape)
    else:
        x = x + torch.randn(size=x.shape)

    return x


def create_ktraj(ndims, klength):
    # Random k-space trajectory with coordinates in [-pi, pi).
    return torch.rand(size=(ndims, klength)) * 2 * np.pi - np.pi


def main():
    is_complex = True
    shape = [1, 4, 32, 16]  # (batch, coils, *im_size)
    kdata_shape = [1, 4, 83]  # (batch, coils, klength)

    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]
    im_shape = [s for s in shape]
    im_shape[1] = 1  # single-channel image; the coil dimension comes from smaps

    image = create_input_plus_noise(im_shape, is_complex)
    kdata = create_input_plus_noise(kdata_shape, is_complex)
    smaps = create_input_plus_noise(shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])
    image.requires_grad = True
    kdata.requires_grad = True

    forw_ob = tkbn.KbNufft(im_size=im_size)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)

    # forward NUFFT with sensitivity maps
    image_forw = forw_ob(image, ktraj, smaps=smaps)

    loss1 = (torch.abs(image_forw) ** 2 / 2).sum()
    loss2 = torch.abs(image_forw - torch.zeros_like(image_forw)).sum()

    # backprop through the NUFFT repeatedly on the same graph
    loss1.backward(retain_graph=True)
    loss1.backward(retain_graph=True)
    loss2.backward()

    print("success!")


if __name__ == "__main__":
    main()

mmuckley commented 2 years ago

Closing due to inactivity - reopen if there are further issues.