Closed mlaves closed 1 year ago
When using batched inputs with varying ktraj sizes, I get a TorchScript RuntimeError when running on a CUDA GPU. On CPU, the error does not occur. See the following example:
# Minimal reproduction: run a batched NUFFT where the k-space trajectory
# length changes on every iteration.
import torch
import torchkbnufft as tkbn

torch.manual_seed(0)
device = torch.device("cuda:0")
batch_size = 16

# One random 256x256 image, promoted to complex and repeated along the
# batch dimension -> shape (batch_size, 1, 256, 256).
x = torch.rand((256, 256))
im_size = x.shape
x = x.unsqueeze(0).unsqueeze(0).to(torch.complex64)
x = x.repeat(batch_size, 1, 1, 1).to(device)

for _ in range(10):
    # Trajectory length drawn from [32, 159]; it differs between iterations.
    klength = 32 + torch.randint(0, 128, size=(1,)).item()
    print(klength)

    # 2 x klength trajectory: first row all zeros, second row spans [-pi, pi].
    ktraj = torch.stack(
        (torch.zeros(klength), torch.linspace(-torch.pi, torch.pi, klength))
    )
    ktraj = ktraj.to(device)
    # Same trajectory for every batch element -> (batch_size, 2, klength).
    ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

    nufft_ob = tkbn.KbNufft(im_size=im_size).to(device)
    kdata = nufft_ob(x, ktraj)
The output for me is:
104
138
88
61
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 21
     18     ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)
     20     nufft_ob = tkbn.KbNufft(im_size=im_size).to(device)
---> 21     kdata = nufft_ob(x, ktraj)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/modules/kbnufft.py:211, in KbNufft.forward(self, image, omega, interp_mats, smaps, norm)
    208 assert isinstance(self.table_oversamp, Tensor)
    209 assert isinstance(self.offsets, Tensor)
--> 211 output = tkbnF.kb_table_nufft(
    212     image=image,
    213     scaling_coef=self.scaling_coef,
    214     im_size=self.im_size,
    215     grid_size=self.grid_size,
    216     omega=omega,
    217     tables=tables,
    218     n_shift=self.n_shift,
    219     numpoints=self.numpoints,
    220     table_oversamp=self.table_oversamp,
    221     offsets=self.offsets.to(torch.long),
    222     norm=norm,
    223 )
    225 if not is_complex:
    226     output = torch.view_as_real(output)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/functional/nufft.py:172, in kb_table_nufft(image, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm)
    169     is_complex = False
    170     image = torch.view_as_complex(image)
--> 172 data = kb_table_interp(
    173     image=fft_and_scale(
    174         image=image,
    175         scaling_coef=scaling_coef,
    176         im_size=im_size,
    177         grid_size=grid_size,
    178         norm=norm,
    179     ),
    180     omega=omega,
    181     tables=tables,
    182     n_shift=n_shift,
    183     numpoints=numpoints,
    184     table_oversamp=table_oversamp,
    185     offsets=offsets,
    186 )
    188 if is_complex is False:
    189     data = torch.view_as_real(data)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/functional/interp.py:117, in kb_table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets)
    114     is_complex = False
    115     image = torch.view_as_complex(image)
--> 117 data = KbTableInterpForward.apply(
    118     image, omega, tables, n_shift, numpoints, table_oversamp, offsets
    119 )
    121 if is_complex is False:
    122     data = torch.view_as_real(data)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_autograd/interp.py:87, in KbTableInterpForward.forward(ctx, image, omega, tables, n_shift, numpoints, table_oversamp, offsets)
     81 """Apply table interpolation.
     82
     83 This is a wrapper for for PyTorch autograd.
     84 """
     85 grid_size = torch.tensor(image.shape[2:], device=image.device)
---> 87 output = table_interp(
     88     image=image,
     89     omega=omega,
     90     tables=tables,
     91     n_shift=n_shift,
     92     numpoints=numpoints,
     93     table_oversamp=table_oversamp,
     94     offsets=offsets,
     95 )
     97 ctx.save_for_backward(
     98     omega, n_shift, numpoints, table_oversamp, offsets, grid_size, *tables
     99 )
    101 return output

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_nufft/interp.py:376, in table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets, min_kspace_per_fork)
    374 if USING_OMP and image.device == torch.device("cpu"):
    375     torch.set_num_threads(threads_per_fork)
--> 376 kdat = table_interp_fork_over_batchdim(
    377     image, omega, tables, n_shift, numpoints, table_oversamp, offsets, num_forks
    378 )
    379 if USING_OMP and image.device == torch.device("cpu"):
    380     torch.set_num_threads(num_threads)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/CODE1/320171775/miniforge3/envs/py310/lib/python3.10/site-packages/torchkbnufft/_nufft/interp.py", line 265, in table_interp_fork_over_batchdim
        # collect the results
        return torch.cat([torch.jit.wait(future) for future in futures])
                          ~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: MALFORMED INPUT: Index out of bounds in check_bounds. Index: 511; bounds: [0, 1). - ret_val(dtype=int64_t, sizes=[1], strides=[1])
After updating torch to 1.13.1 and torchkbnufft to 1.4.0, the problem was fixed.
torch
torchkbnufft
@mlaves thanks so much for posting the fix!
When using batched inputs with varying ktraj sizes, I get a TorchScript RuntimeError when running on a CUDA GPU. On CPU, the error does not occur. See the following example:
The output for me is: