PTB-MR / mrpro

MR image reconstruction and processing.
https://ptb-mr.github.io/mrpro/
Apache License 2.0

POCS for partial Fourier #323

Open fzimmermann89 opened 3 months ago

fzimmermann89 commented 3 months ago

At some point we might want to have POCS partial Fourier filling. As far as I know, even if we use an iterative reconstruction algorithm, it can provide a better initial guess x0. It relies on the assumption that the image phase varies only slowly, so it can be estimated from the fully sampled center region of k-space.
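The symmetry that POCS exploits is easiest to see in the zero-phase case. For a real-valued image, k-space is conjugate symmetric, S(-k) = S*(k), so every missing line is the complex conjugate of its mirrored, measured partner. The numpy sketch below is only an illustration, not mrpro code; the 1D setup and all variable names are assumptions. It uses an odd length so that after `fftshift` the DC sample sits at index `n // 2` and index `i` mirrors to `n - 1 - i`. Real MR images have a non-zero phase, which is why POCS estimates that phase from the center region instead of relying on exact symmetry.

```python
import numpy as np

# Illustrative 1D example: a real-valued "image", i.e. zero phase everywhere.
rng = np.random.default_rng(0)
n = 63  # odd length: after fftshift, DC is at index n // 2 and index i mirrors to n - 1 - i
image = rng.standard_normal(n)

# Centered k-space of the real-valued image.
kspace = np.fft.fftshift(np.fft.fft(image))

# Partial Fourier: the first n_missing lines are not acquired (zero-filled).
n_missing = 20  # must stay below n // 2 so that every mirrored partner was measured
measured = kspace.copy()
measured[:n_missing] = 0

# Fill each missing line with the complex conjugate of its mirrored, measured partner.
mirror = n - 1 - np.arange(n_missing)
filled = measured.copy()
filled[:n_missing] = np.conj(measured[mirror])

recon = np.fft.ifft(np.fft.ifftshift(filled))
print(np.allclose(filled, kspace), np.allclose(recon.real, image))  # True True
```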

fzimmermann89 commented 3 months ago

This could be a possible starting point:


#(c) Felix Zimmermann (fzimmermann89@gmail.com), BSD 2-Clause.
import torch

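def find_symmetric_region(kdata: torch.Tensor) -> tuple[int, slice]:
    """
    Determine the region of k-space that is sampled symmetrically around the center in y direction.

    The k-space data is assumed to be centered (DC at index Ny // 2), with the missing lines
    zero-filled at the beginning of the y direction.

    Parameters
    ----------
    kdata : torch.Tensor
        Input k-space data tensor of shape (..., Ny, Nx).

    Returns
    -------
    tuple
        Center index (Ny // 2) and slice object defining the symmetric region.
    """
    Ny = kdata.shape[-2]
    center = Ny // 2
    # Signal per y line, summed over x and over all leading dimensions (e.g. coils)
    line_energy = kdata.abs().sum(dim=-1).reshape(-1, Ny).sum(dim=0)
    first_nonzero = int(torch.nonzero(line_energy > 0)[0].item())
    if first_nonzero > center:
        raise ValueError('Less than half of k-space is sampled, no symmetric center region exists.')
    # Lines first_nonzero ... 2 * center - first_nonzero are mirror images of each other
    sym_start = first_nonzero
    sym_end = min(Ny, 2 * center - first_nonzero + 1)
    return center, slice(sym_start, sym_end)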

def pocs(kdata: torch.Tensor, n_iterations: int = 20, smooth_transition: bool = True) -> torch.Tensor:
    """
    Perform Projection onto Convex Sets (POCS) algorithm to reconstruct missing k-space data.

    The k-space data is assumed to be centered (DC at index Ny // 2), and the missing data is
    assumed to be a zero-filled region at the beginning of the y direction.

    Parameters
    ----------
    kdata
        Input k-space data tensor of shape (..., Ny, Nx).
    n_iterations
        Number of iterations for the POCS algorithm.
    smooth_transition
        If true, apply a smooth transition between measured and reconstructed data.

    Returns
    -------
    torch.Tensor
        Reconstructed k-space data tensor of the same shape as input kdata.
    """
    # Find the symmetrically sampled center region
    _, sym_region = find_symmetric_region(kdata)
    n_sym_center = sym_region.stop - sym_region.start

    # Create low-res image from the (symmetrically) windowed center region and extract its phase.
    # The FFTs are not shifted: for centered k-space this only adds a linear phase ramp (and a
    # circular shift) in image space, which the low-res image shares, so it cancels in the iteration.
    window = torch.hamming_window(n_sym_center, periodic=False, device=kdata.device)
    kdata_lowres = torch.zeros_like(kdata)
    kdata_lowres[..., sym_region, :] = kdata[..., sym_region, :] * window[:, None]
    angle = torch.angle(torch.fft.ifftn(kdata_lowres, dim=(-2, -1), norm='ortho'))

    # Create mask for measured data
    mask = kdata != 0

    # Iterate: impose the low-res phase in image space, then restore the measured k-space data
    kdata_current = kdata.clone()
    kdata_new = kdata.clone()  # keeps the transition below valid for n_iterations == 0
    for _ in range(n_iterations):
        img = torch.fft.ifftn(kdata_current, dim=(-2, -1), norm='ortho')
        img = torch.polar(img.abs(), angle)
        kdata_new = torch.fft.fftn(img, dim=(-2, -1), norm='ortho')
        kdata_current = torch.where(mask, kdata, kdata_new)

    # Smooth transition at the edge of the measured data: over the first n_transition measured
    # lines, blend from reconstructed to measured data (rising half of a Hann window without
    # its end points) to avoid a step in k-space
    n_transition = (n_sym_center - 1) // 3
    if smooth_transition and n_transition > 0:
        ramp = torch.hann_window(2 * n_transition + 3, periodic=False, device=kdata.device)
        ramp = ramp[1 : n_transition + 1, None]
        trans_slice = slice(sym_region.start, sym_region.start + n_transition)
        kdata_current[..., trans_slice, :] = (
            kdata[..., trans_slice, :] * ramp + kdata_new[..., trans_slice, :] * (1 - ramp)
        )

    return kdata_current
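Before porting this to mrpro, the same iteration can be sanity-checked on synthetic data. The sketch below is a hypothetical, simplified 1D numpy version of the loop above (`fft_c`, `ifft_c` and `pocs_1d` are illustrative names, not mrpro functions). It assumes centered k-space with the first `n_missing` lines missing, estimates the phase from the Hamming-windowed symmetric center, and compares the result with plain zero filling on a box-shaped magnitude with a slowly varying phase.

```python
import numpy as np


def fft_c(x: np.ndarray) -> np.ndarray:
    """Centered, orthonormal 1D FFT: DC ends up at index n // 2."""
    return np.fft.fftshift(np.fft.fft(np.fft.ifftshift(x), norm="ortho"))


def ifft_c(k: np.ndarray) -> np.ndarray:
    """Inverse of fft_c."""
    return np.fft.fftshift(np.fft.ifft(np.fft.ifftshift(k), norm="ortho"))


def pocs_1d(measured: np.ndarray, n_missing: int, n_iterations: int = 20) -> np.ndarray:
    """Hypothetical 1D POCS: lines [0, n_missing) are missing, all others are measured."""
    n = measured.shape[0]
    center = n // 2
    # Fully sampled region that is symmetric around the center: [n_missing, 2 * center - n_missing]
    sym = slice(n_missing, min(n, 2 * center - n_missing + 1))
    lowres = np.zeros_like(measured)
    lowres[sym] = measured[sym] * np.hamming(sym.stop - sym.start)
    phase = np.angle(ifft_c(lowres))  # slowly varying phase estimate

    mask = np.arange(n) >= n_missing  # True where data was measured
    current = measured.copy()
    for _ in range(n_iterations):
        img = ifft_c(current)
        img = np.abs(img) * np.exp(1j * phase)  # projection onto the phase constraint
        current = np.where(mask, measured, fft_c(img))  # projection onto the measured data
    return current


# Synthetic test object: box-shaped magnitude with sharp edges and a slowly varying phase
n = 128
y = np.arange(n)
magnitude = np.where((y >= 40) & (y < 88), 1.0, 0.1)
image = magnitude * np.exp(1j * 0.5 * np.sin(2 * np.pi * y / n))

# Partial Fourier: the first 40 of 128 lines are not acquired
n_missing = 40
measured = fft_c(image)
measured[:n_missing] = 0

err_zero = np.linalg.norm(ifft_c(measured) - image)
err_pocs = np.linalg.norm(ifft_c(pocs_1d(measured, n_missing)) - image)
print(f"zero-filled error: {err_zero:.3f}, POCS error: {err_pocs:.3f}")
```

Because the phase constraint uses the estimated rather than the true phase, the POCS error generally does not reach zero; how close it gets depends on how smooth the phase really is.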