mmuckley / torchkbnufft

A high-level, easy-to-deploy non-uniform Fast Fourier Transform in PyTorch.
https://torchkbnufft.readthedocs.io/
MIT License
209 stars 44 forks source link

Batched input with varying size leads to TorchScript error #89

Closed — mlaves closed this issue 1 year ago

mlaves commented 1 year ago

When I call `KbNufft` repeatedly on batched inputs whose `ktraj` length changes from call to call, I get a TorchScript `RuntimeError` on a CUDA GPU. The error does not occur on the CPU. See the following example:

import torch
import torchkbnufft as tkbn
torch.manual_seed(0)

device = torch.device("cuda:0")

batch_size = 16
x = torch.rand((256, 256))
im_size = x.shape
x = x.unsqueeze(0).unsqueeze(0).to(torch.complex64)
x = x.repeat(batch_size, 1, 1, 1).to(device)

for _ in range(10):
    klength = 32 + torch.randint(0, 128, size=(1,)).item()
    print(klength)
    ktraj = torch.stack(
        (torch.zeros(klength), torch.linspace(-torch.pi, torch.pi, klength))
    )
    ktraj = ktraj.to(device)
    ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

    nufft_ob = tkbn.KbNufft(im_size=im_size).to(device)
    kdata = nufft_ob(x, ktraj)

The output I get is:

104
138
88
61
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 21
     18 ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)
     20 nufft_ob = tkbn.KbNufft(im_size=im_size).to(device)
---> 21 kdata = nufft_ob(x, ktraj)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/modules/kbnufft.py:211, in KbNufft.forward(self, image, omega, interp_mats, smaps, norm)
    208     assert isinstance(self.table_oversamp, Tensor)
    209     assert isinstance(self.offsets, Tensor)
--> 211     output = tkbnF.kb_table_nufft(
    212         image=image,
    213         scaling_coef=self.scaling_coef,
    214         im_size=self.im_size,
    215         grid_size=self.grid_size,
    216         omega=omega,
    217         tables=tables,
    218         n_shift=self.n_shift,
    219         numpoints=self.numpoints,
    220         table_oversamp=self.table_oversamp,
    221         offsets=self.offsets.to(torch.long),
    222         norm=norm,
    223     )
    225 if not is_complex:
    226     output = torch.view_as_real(output)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/functional/nufft.py:172, in kb_table_nufft(image, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm)
    169     is_complex = False
    170     image = torch.view_as_complex(image)
--> 172 data = kb_table_interp(
    173     image=fft_and_scale(
    174         image=image,
    175         scaling_coef=scaling_coef,
    176         im_size=im_size,
    177         grid_size=grid_size,
    178         norm=norm,
    179     ),
    180     omega=omega,
    181     tables=tables,
    182     n_shift=n_shift,
    183     numpoints=numpoints,
    184     table_oversamp=table_oversamp,
    185     offsets=offsets,
    186 )
    188 if is_complex is False:
    189     data = torch.view_as_real(data)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/functional/interp.py:117, in kb_table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets)
    114     is_complex = False
    115     image = torch.view_as_complex(image)
--> 117 data = KbTableInterpForward.apply(
    118     image, omega, tables, n_shift, numpoints, table_oversamp, offsets
    119 )
    121 if is_complex is False:
    122     data = torch.view_as_real(data)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_autograd/interp.py:87, in KbTableInterpForward.forward(ctx, image, omega, tables, n_shift, numpoints, table_oversamp, offsets)
     81 """Apply table interpolation.
     82 
     83 This is a wrapper for for PyTorch autograd.
     84 """
     85 grid_size = torch.tensor(image.shape[2:], device=image.device)
---> 87 output = table_interp(
     88     image=image,
     89     omega=omega,
     90     tables=tables,
     91     n_shift=n_shift,
     92     numpoints=numpoints,
     93     table_oversamp=table_oversamp,
     94     offsets=offsets,
     95 )
     97 ctx.save_for_backward(
     98     omega, n_shift, numpoints, table_oversamp, offsets, grid_size, *tables
     99 )
    101 return output

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_nufft/interp.py:376, in table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets, min_kspace_per_fork)
    374 if USING_OMP and image.device == torch.device("cpu"):
    375     torch.set_num_threads(threads_per_fork)
--> 376 kdat = table_interp_fork_over_batchdim(
    377     image, omega, tables, n_shift, numpoints, table_oversamp, offsets, num_forks
    378 )
    379 if USING_OMP and image.device == torch.device("cpu"):
    380     torch.set_num_threads(num_threads)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/CODE1/320171775/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_nufft/interp.py", line 265, in table_interp_fork_over_batchdim

    # collect the results
    return torch.cat([torch.jit.wait(future) for future in futures])
                      ~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: MALFORMED INPUT: Index out of bounds in check_bounds. Index: 511; bounds: [0, 1). - ret_val(dtype=int64_t, sizes=[1], strides=[1])
mlaves commented 1 year ago

After updating torch to 1.13.1 and torchkbnufft to 1.4.0, the error no longer occurs.

mmuckley commented 1 year ago

@mlaves thanks so much for posting the fix!