mmuckley / torchkbnufft

A high-level, easy-to-deploy non-uniform Fast Fourier Transform in PyTorch.
https://torchkbnufft.readthedocs.io/
MIT License

KB-NUFFT hangs when used in PyTorch DataLoader with num_workers > 0 #73

Open maartenterpstra opened 1 year ago

maartenterpstra commented 1 year ago

Hi,

I'm trying to perform on-the-fly data undersampling in my PyTorch dataset. To do this, I perform a Toeplitz NUFFT in the __getitem__ function of my Dataset class. This works as expected. Now, I want to do batching, so I wrap the PyTorch Dataset in a PyTorch DataLoader. This also works as expected when num_workers=0. However, when num_workers is non-zero, computation of the NUFFT seemingly enters an infinite loop.

Expected behaviour

Performing a NUFFT in parallel using multiple workers should result in undersampled images.

Observed behaviour

Sampling from the dataloader causes the script to hang, seemingly stuck in an infinite loop.

Extra information

Minimal example

import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import rescale
import numpy as np
import torchkbnufft as tkbn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

NUM_WORKERS=0 # set to > 0 to hang

class DataUndersampler(Dataset):
    def __init__(self, undersampling_factor, spatial_scaling_factor):
        self.image = shepp_logan_phantom()
        self.image_shape = self.image.shape
        self.undersampling_factor = undersampling_factor
        self.spatial_scaling_factor = spatial_scaling_factor
        # Need the original size to determine the undersampling factor and NUFFT operators
        self.orig_size = self.image_shape
        self.image_shape = (self.image_shape[0] // self.spatial_scaling_factor, self.image_shape[1] // self.spatial_scaling_factor)
        # Create an oversampled grid
        spokelength = self.image_shape[0] * 2
        self.grid_size = (spokelength, spokelength)

        # Generate A LOT of spokes, pick a starting point at random
        nspokes = 2000

        # Sample enough spokes to achieve undersampling factor
        self.spokes_to_sample = int((self.orig_size[0] * np.pi / 2) / self.undersampling_factor)
        # Generate a golden angle radial trajectory
        ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
        kx = np.zeros(shape=(spokelength, nspokes))
        ky = np.zeros(shape=(spokelength, nspokes))
        ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
        for i in range(1, nspokes):
            kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
            ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]

        self.ky = np.transpose(ky)
        self.kx = np.transpose(kx)
        # 1D Ram-Lak filter, needed for density compensation. The sampling density
        # is a linear function of the distance from the center of k-space...
        ram_lak = np.abs(np.linspace(-1, 1, spokelength + 1))
        ram_lak = ram_lak[:-1]

        #... except for the center, we know exactly how often we sample that,
        # namely as many times as the number of spokes
        middle_idx = len(ram_lak) // 2
        ram_lak[middle_idx] = 1/(2 * self.spokes_to_sample)
        self.ram_lak = ram_lak

    def __len__(self):
        return 1

    def __getitem__(self, index):
        if self.spatial_scaling_factor > 1:
            img = rescale(self.image, 1/self.spatial_scaling_factor).astype(complex)
        else:
            img = self.image.astype(complex)

        if self.undersampling_factor > 1:
            img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)

            toep_ob = tkbn.ToepNufft()
            # We pick a random starting spoke
            offset = np.random.choice(range(self.ky.shape[0]-self.spokes_to_sample))

            # And select as many subsequent spokes to reach the desired undersampling factor
            # todo: use continuing trajectories for cine?
            selected_ky = self.ky[offset:offset+self.spokes_to_sample].flatten()
            selected_kx = self.kx[offset:offset+self.spokes_to_sample].flatten()

            ktraj = torch.tensor(np.stack((selected_ky, selected_kx), axis=0))
            # Repeat and reshape the ram-lak so every spoke is density-compensated
            ram_lak_t = torch.from_numpy(np.tile(self.ram_lak, self.spokes_to_sample)).unsqueeze(0).unsqueeze(0)
            # Calculate the really efficient Toeplitz kernel to compute the NUFFT
            dcomp_kernel = tkbn.calc_toeplitz_kernel(ktraj, self.image_shape, weights=ram_lak_t, norm='ortho', numpoints=(3, 3))  # with density compensation
            # And in a single step, compute the radial k-space and back to image space
            img = toep_ob(img_tensor, dcomp_kernel, norm='ortho').abs().squeeze().numpy()
            # renormalize the output, because undersampling can change the overall image scale.
            img /= np.max(img)

        return np.abs(img)

# Create dataset
dset = DataUndersampler(
            undersampling_factor=2, 
            spatial_scaling_factor=1)

# this works fine
undersampled_img = dset[0]

# From here, the observed behaviour emerges. 

dloader = DataLoader(dset,
            shuffle=False,
            batch_size=1, num_workers=NUM_WORKERS)

# this statement hangs when num_workers > 0
undersampled_img = next(iter(dloader))
mmuckley commented 1 year ago

Hello @maartenterpstra, I think the issue is due to the table-based NUFFT, which is used inside tkbn.calc_toeplitz_kernel.

Does every sample have a different trajectory? If they're all the same, you could apply the NUFFT outside the dataloader.
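
For illustration, a rough sketch of that workaround, assuming a single trajectory shared by all samples so the Toeplitz kernel is built once in the main process (the dataset and variable names below are made up for this example, not part of torchkbnufft):

import numpy as np
import torch
import torchkbnufft as tkbn
from skimage.data import shepp_logan_phantom
from torch.utils.data import Dataset, DataLoader


class RawImageDataset(Dataset):
    """Illustrative dataset: returns raw complex images, no NUFFT inside workers."""

    def __init__(self, n_samples=4):
        self.image = shepp_logan_phantom().astype(complex)
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        # (1, H, W) so the default collate produces (N, 1, H, W) batches
        return torch.from_numpy(self.image).unsqueeze(0)


im_size = shepp_logan_phantom().shape
spokelength, nspokes = im_size[0] * 2, 64

# One shared golden-angle radial trajectory, built once in the main process
ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
kx = np.zeros((spokelength, nspokes))
ky = np.zeros((spokelength, nspokes))
ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
for i in range(1, nspokes):
    kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
    ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]
ktraj = torch.tensor(np.stack((ky.T.flatten(), kx.T.flatten()), axis=0))

# Kernel computed once, outside any worker process
toep_ob = tkbn.ToepNufft()
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm='ortho')

loader = DataLoader(RawImageDataset(), batch_size=2, num_workers=4)
for batch in loader:
    # The NUFFT happens here, in the main process, on the whole batch at once
    undersampled = toep_ob(batch, kernel, norm='ortho').abs()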

maartenterpstra commented 1 year ago

Hi @mmuckley. I was also thinking that as a workaround I could compute the NUFFT for a single batch outside the dataloader. In general, every sample has a different trajectory but the same number of spokes. Would this be possible?

mmuckley commented 1 year ago

Hello @maartenterpstra, it may be more efficient to loop over the list or use a batched NUFFT. The batched NUFFT is good for a large number of small NUFFTs. You can see how to use it here.
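
To sketch the batched-trajectory idea (this is my own illustration rather than the code from the linked example; the batched omega shape of (N, 2, klength) is my reading of the docs, so please double-check it against the batched NUFFT example):

import math

import torch
import torchkbnufft as tkbn

im_size = (200, 200)
klength = 200 * 40   # same number of spokes per sample, so trajectory lengths match
batch_size = 8

nufft_ob = tkbn.KbNufft(im_size=im_size)
adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)

# One image and one trajectory per batch element:
# images (N, 1, H, W), trajectories (N, 2, klength) with values in [-pi, pi)
images = torch.randn(batch_size, 1, *im_size, dtype=torch.complex64)
ktraj = (torch.rand(batch_size, 2, klength) - 0.5) * 2 * math.pi

# A single call handles the whole batch of small NUFFTs
kdata = nufft_ob(images, ktraj, norm='ortho')           # (N, 1, klength)
recon = adjnufft_ob(kdata, ktraj, norm='ortho').abs()   # (N, 1, H, W)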

I also opened #74 as a potential enhancement with a pointer to where the code controls threading if you'd be interested in that route.
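
In the meantime, one thing that might be worth trying (untested, just a guess at the threading interaction after fork) is limiting each worker to a single intra-op thread via worker_init_fn, reusing dset from the example above:

import torch
from torch.utils.data import DataLoader


def single_threaded_worker(worker_id):
    # Restrict PyTorch intra-op parallelism inside each forked worker
    torch.set_num_threads(1)


dloader = DataLoader(dset, shuffle=False, batch_size=1, num_workers=4,
                     worker_init_fn=single_threaded_worker)
undersampled_img = next(iter(dloader))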