[Open] fzimmermann89 opened this issue 3 months ago
This could be a possible starting point:
#(c) Felix Zimmermann (fzimmermann89@gmail.com), BSD 2-Clause.
import torch
def find_symmetric_region(kdata: torch.Tensor) -> tuple[int, slice]:
    """
    Determine the symmetrically sampled region around the center of k-space in y direction.

    The k-space is assumed to be partial-Fourier sampled along y (dimension -2), with
    the unsampled lines zero-filled at the beginning of that dimension. A line counts
    as sampled if it contains any non-zero value (across all other dimensions). If the
    first sampled line is ``first``, the returned slice is
    ``slice(first, 2 * center - first)`` with ``center = Ny // 2``, i.e. the lines that
    are sampled equally far below and above the center line.

    Parameters
    ----------
    kdata : torch.Tensor
        Input k-space data tensor of shape (..., Ny, Nx). Any number of leading
        dimensions (including none) is supported.

    Returns
    -------
    tuple
        Center index (Ny // 2) and slice object defining the symmetric region.

    Raises
    ------
    ValueError
        If kdata has fewer than two dimensions, contains no non-zero values, or its
        first sampled line is not before the center line (no symmetric region exists).
    """
    if kdata.ndim < 2:
        raise ValueError(f"kdata must have shape (..., Ny, Nx), got shape {tuple(kdata.shape)}.")
    n_y = kdata.shape[-2]
    center = n_y // 2

    # Collapse every dimension except y: a line is sampled if any of its values is non-zero.
    line_sampled = (kdata != 0).movedim(-2, 0).flatten(start_dim=1).any(dim=1)
    sampled_lines = torch.nonzero(line_sampled)
    if sampled_lines.numel() == 0:
        raise ValueError("kdata contains no non-zero data.")
    first_sampled = int(sampled_lines[0, 0])
    if first_sampled >= center:
        raise ValueError(
            f"First sampled line ({first_sampled}) is not before the center line ({center}); "
            "no symmetric center region exists."
        )

    # The region reaches as far above the center as the sampled data reaches below it,
    # so its length is 2 * (center - first_sampled) (not 2 * (Ny - first_sampled)).
    return center, slice(first_sampled, 2 * center - first_sampled)
def pocs(kdata: torch.Tensor, n_iterations: int = 20, smooth_transition: bool = True) -> torch.Tensor:
    """
    Perform Projection onto Convex Sets (POCS) algorithm to reconstruct missing k-space data.

    The missing data is assumed to be a zero-filled region at the beginning of the
    y-direction (dimension -2). The image phase is estimated from a Hamming-windowed copy
    of the symmetrically sampled center region (see ``find_symmetric_region``). Each
    iteration imposes this phase on the current image estimate and then restores the
    measured (non-zero) k-space samples.

    Parameters
    ----------
    kdata
        Input k-space data tensor of shape (..., Ny, Nx). Any number of leading
        dimensions is supported.
    n_iterations
        Number of iterations for the POCS algorithm. If it is <= 0, no reconstruction
        is performed and a copy of ``kdata`` is returned.
    smooth_transition
        If true, blend measured and reconstructed data with a smooth rising ramp over
        the first measured lines, i.e. where the reconstructed region meets the
        measured region.

    Returns
    -------
    torch.Tensor
        Reconstructed k-space data tensor of the same shape as input kdata.
    """
    # Find size of the symmetrically sampled center region
    _, sym_region = find_symmetric_region(kdata)
    n_sym_center = sym_region.stop - sym_region.start

    # Create low-res image from the windowed symmetric region and extract its phase.
    # Windows are created on kdata's device so that e.g. CUDA inputs work.
    lowres_window = torch.hamming_window(n_sym_center, device=kdata.device)[:, None]
    kdata_lowres = torch.zeros_like(kdata)
    kdata_lowres[..., sym_region, :] = kdata[..., sym_region, :] * lowres_window
    angle = torch.angle(torch.fft.ifftn(kdata_lowres, dim=(-2, -1), norm='ortho'))

    # Non-zero samples are treated as measured and kept fixed in every iteration
    mask = kdata != 0

    kdata_current = kdata.clone()
    kdata_new = None  # stays None when no iteration runs
    for _ in range(n_iterations):
        img = torch.fft.ifftn(kdata_current, dim=(-2, -1), norm='ortho')
        img = torch.polar(img.abs(), angle)
        kdata_new = torch.fft.fftn(img, dim=(-2, -1), norm='ortho')
        kdata_current = torch.where(mask, kdata, kdata_new)

    # Smooth transition between reconstructed and measured data
    if smooth_transition and kdata_new is not None:
        n_y = kdata.shape[-2]
        # First y-line containing any measured sample: this is the boundary between
        # the reconstructed (zero-filled) region and the measured region.
        line_measured = mask.movedim(-2, 0).flatten(start_dim=1).any(dim=1)
        first_measured = int(torch.nonzero(line_measured)[0, 0])
        n_transition = min((n_sym_center - 1) // 3, n_y - first_measured)
        if n_transition > 0:
            # Monotone rising half-Hann ramp with weights strictly between 0 and 1:
            # sin^2(pi * k / (2 * (n_transition + 1))) for k = 1..n_transition.
            # The line before the ramp is fully reconstructed and the line after it
            # fully measured, so no jump is introduced at either end.
            ramp = torch.hann_window(2 * (n_transition + 1), device=kdata.device)[1 : n_transition + 1]
            ramp = ramp[:, None]
            trans_slice = slice(first_measured, first_measured + n_transition)
            # Blend the data-consistent estimate (measured samples where available)
            # with the pure POCS prediction.
            kdata_current[..., trans_slice, :] = (
                kdata_current[..., trans_slice, :] * ramp
                + kdata_new[..., trans_slice, :] * (1 - ramp)
            )
    return kdata_current
At some point we might want to have POCS-based partial Fourier filling. As far as I know, even if we use an iterative reconstruction algorithm, it can be used to provide a better initial guess x0. It relies on the assumption that the image phase changes only slowly and can therefore be estimated from the fully sampled center region.