PTB-MR / mrpro

MR image reconstruction and processing.
https://ptb-mr.github.io/mrpro/
Apache License 2.0

Undersampled Image Reconstruction using operators #381

Open · Pierrickkk opened this issue 1 month ago

Pierrickkk commented 1 month ago

I am generating k-space data using the EllipsePhantom and performing a direct reconstruction to obtain a fully sampled image. This image serves as the ground truth from which I want to retrospectively simulate undersampled k-space data.

To obtain undersampled data, I split the k-space data along k1 and create a DirectReconstruction object from the undersampled k-space data.
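In the code below, the split is done with mrpro's KData.split_k1_into_other. The underlying idea, keeping every R-th interleaf, can be sketched in plain PyTorch. This is an illustration only, not mrpro's implementation; it assumes k1 is the second-to-last dimension of both the data and the trajectory, and `undersample_interleaves` is a hypothetical helper.

```python
import torch


def undersample_interleaves(kdata: torch.Tensor, traj: torch.Tensor, factor: int):
    """Keep every `factor`-th interleaf (k1 index) of k-space data and trajectory.

    Assumes a (..., k1, k0) layout for both tensors. Hypothetical helper,
    not part of mrpro.
    """
    n_k1 = kdata.shape[-2]
    idx = torch.arange(0, n_k1, factor)  # e.g. 0, 4, 8, ... for factor=4
    return kdata[..., idx, :], traj[..., idx, :]


# Toy tensors: (other, coils, k2, k1, k0) data and (other, k2, k1, k0) trajectory
kdata = torch.randn(1, 1, 1, 128, 220, dtype=torch.complex64)
traj = torch.randn(1, 1, 128, 220)
kdata_us, traj_us = undersample_interleaves(kdata, traj, factor=4)
print(kdata_us.shape, traj_us.shape)
```

With 128 interleaves and factor 4 this keeps 32 interleaves, matching the `torch.arange(0, 128, 4)` index set used below.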

I then create an acquisition operator to retrospectively generate the k-space data from the ground-truth image.

The image reconstructed from this retrospectively simulated data (xu) should ideally match the direct reconstruction of the undersampled data (x_us). Both should also match the ground-truth image in overall intensity, so that the point-wise error images contain only undersampling artefacts.

The problem is that while x_us matches the fully sampled ground truth (x_fullsamp) and looks correct, xu does not. The issue persists even though I use the density compensation function from the direct reconstruction and a Fourier operator set up from the same undersampled data.
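Applying the DCF once as W, or splitting it as W^(1/2) on both sides of the adjoint, should give the same image as long as (W^(1/2))^2 = W. A minimal sketch of that check, assuming a small dense non-uniform DFT matrix as a stand-in for the NUFFT (it is not mrpro's FourierOp) and random positive weights as a stand-in DCF:

```python
import math

import torch

torch.manual_seed(0)

# Toy 1-D non-uniform DFT: N image pixels, M non-uniform k-space samples.
# The dense matrix F stands in for a NUFFT; it is not mrpro's FourierOp.
N, M = 32, 48
x_pos = torch.arange(N, dtype=torch.float64) - N // 2  # pixel positions
k_pos = torch.rand(M, dtype=torch.float64) - 0.5  # k-space positions in [-0.5, 0.5)
F = torch.exp(-2j * math.pi * k_pos[:, None] * x_pos[None, :]) / N**0.5

w = torch.rand(M, dtype=torch.float64) + 0.1  # stand-in DCF, strictly positive
x = torch.randn(N, dtype=torch.complex128)

# Version 1: y = F x, then x0 = F^H W y
y = F @ x
x0_v1 = F.conj().T @ (w * y)

# Version 2: A = W^(1/2) F, y_w = A x, then x0 = A^H y_w
A = w.sqrt()[:, None] * F
x0_v2 = A.conj().T @ (A @ x)

print(torch.allclose(x0_v1, x0_v2))
```

Because both versions agree, a mismatch between xu and x_us points to a scaling difference in the operators themselves rather than to how the DCF is split.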

Here is code you can try:

# %% NOT IMPORTANT Import and data 

import mrpro
from mrpro.data import SpatialDimension
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction
from mrpro.data._kdata.KData import KData  
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.traj_calculators.KTrajectoryPulseq import KTrajectoryPulseq
from mrpro.operators.FourierOp import FourierOp
from mrpro.phantoms.EllipsePhantom import EllipsePhantom  
from mrpro.phantoms.phantom_elements import EllipseParameters
from mrpro.operators.DensityCompensationOp import DensityCompensationOp

base_path = '/home/global/mrpro_issue381/'
h5_path = base_path + 'pulseq_spiral_2D_220k0_128interleaves_golden_angle_vds_with_traj.h5'
seq_path = base_path + '20240319_spiral_2D_256mm_220k0_128interleaves_golden_angle_vds.seq'

def shift_k_space_trajectory(kdatapuls: KData) -> KData:

    # Extract k-space trajectory
    ky_pulseq = kdatapuls.traj.ky
    kx_pulseq = kdatapuls.traj.kx
    kz_pulseq = kdatapuls.traj.kz

    # Shift each interleaf so that its trajectory starts at k = 0.
    # Subtracting the first k0 sample of every interleaf at once also covers
    # the last interleaf, which a loop over range(num_indices - 1) would skip.
    shifted_ky = ky_pulseq - ky_pulseq[..., :1]
    shifted_kx = kx_pulseq - kx_pulseq[..., :1]

    shifted_traj = KTrajectory(kx=shifted_kx, ky=shifted_ky, kz=kz_pulseq)
    shifted_kdatapuls = KData(data=kdatapuls.data, traj=shifted_traj, header=kdatapuls.header)

    return shifted_kdatapuls

kdatapuls = KData.from_file(h5_path, KTrajectoryPulseq(seq_path=seq_path))
shifted_kdatapuls = shift_k_space_trajectory(kdatapuls)

shifted_kdatapuls.header.recon_matrix.x = 256
shifted_kdatapuls.header.recon_matrix.y = 256

# %% NOT IMPORTANT generate_random_ellipses

def generate_random_ellipses(num_ellipses):
    ellipses = []
    for _ in range(num_ellipses):
        # Generate radius first to use it for constraining center coordinates
        radius_x = np.random.uniform(0.05, 0.4)
        radius_y = np.random.uniform(0.05, 0.4)

        # Bounds on the centre so that the ellipse stays within [-0.4, 0.4]
        min_center_x = -0.4 + radius_x 
        max_center_x = 0.4 - radius_x
        min_center_y = -0.4 + radius_y
        max_center_y = 0.4 - radius_y

        center_x = np.random.uniform(min_center_x, max_center_x)
        center_y = np.random.uniform(min_center_y, max_center_y)

        intensity = np.random.uniform(1, 50)
        ellipses.append(EllipseParameters(center_x, center_y, radius_x, radius_y, intensity))

    return ellipses

# %% Look here
ellipses = generate_random_ellipses(8)
phantom = EllipsePhantom(ellipses)
kspace_data = phantom.kspace(shifted_kdatapuls.traj.ky, shifted_kdatapuls.traj.kx)

kdata_object = KData(
    data=kspace_data.unsqueeze(0), header=shifted_kdatapuls.header, traj=shifted_kdatapuls.traj
)

kdata_object_us = KData.split_k1_into_other(kdata_object, torch.arange(0, 128, 4)[None, :], other_label='repetition')

direct_reconstruction_fullsamp = DirectReconstruction.from_kdata(kdata_object)
x_fullsamp = direct_reconstruction_fullsamp(kdata_object)

direct_reconstruction_us = DirectReconstruction.from_kdata(kdata_object_us)
x_us = direct_reconstruction_us(kdata_object_us)

fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata=kdata_object_us)
dcf_operator = DensityCompensationOp(torch.sqrt(direct_reconstruction_us.dcf.data)) 
acquisition_operator = dcf_operator @ fourier_operator

kdata = acquisition_operator(x_fullsamp.data)[0]
(xu,) = acquisition_operator.H(kdata)

koflera commented 4 weeks ago

I put the necessary data to run the code on the reco-cluster and adapted the base_path in the code to be able to load the data.

koflera commented 4 weeks ago

Ok, I think I found at least an easy fix, so that you can continue working on your stuff in the meantime. When defining the dcf_operator from scratch, the correct thing seems to be to construct it from 2 * torch.sqrt(direct_reconstruction_us.dcf.data), i.e. twice the square root of the DCF from the direct reconstruction. Then the images are all in the same range, and the difference images contain only artefacts related to the undersampling (see the figure below). I am not sure where the additional factor of 2 comes from, though.
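Because the same weighting appears in both the forward and the adjoint operator, a scalar factor c on W^(1/2) scales the adjoint image by c^2. The factor of 2 on the square-rooted DCF therefore amounts to a factor of 4 on the image. A sketch with random dense matrices as stand-ins (assumptions, not mrpro objects); `adjoint_recon` is a hypothetical helper:

```python
import torch

torch.manual_seed(0)

# Dense stand-ins: F for the Fourier operator, dcf for the density
# compensation returned by the direct reconstruction.
M, N = 40, 16
F = torch.randn(M, N, dtype=torch.complex128)
dcf = torch.rand(M, dtype=torch.float64) + 0.1
x = torch.randn(N, dtype=torch.complex128)


def adjoint_recon(c: float) -> torch.Tensor:
    """Compute A^H A x with A = diag(c * sqrt(dcf)) F."""
    A = (c * dcf.sqrt())[:, None] * F
    return A.conj().T @ (A @ x)


x_1 = adjoint_recon(1.0)  # weighting sqrt(dcf), as in the original post
x_2 = adjoint_recon(2.0)  # weighting 2 * sqrt(dcf), as in the corrected operator
ratio = (x_2.abs().sum() / x_1.abs().sum()).item()
print(ratio)
```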

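[Figure: issue381. Panels: retrospectively undersampled, direct recon from split data, ground-truth; top row magnitude images, bottom row error images]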

ellipses = generate_random_ellipses(8)
phantom = EllipsePhantom(ellipses)
kspace_data = phantom.kspace(shifted_kdatapuls.traj.ky, shifted_kdatapuls.traj.kx)

kdata_object = KData(data=kspace_data.unsqueeze(0), header=shifted_kdatapuls.header, traj=shifted_kdatapuls.traj)

kdata_object_us = KData.split_k1_into_other(kdata_object, torch.arange(0, 128, 16)[None, :], other_label='repetition')

direct_reconstruction_fullsamp = DirectReconstruction.from_kdata(kdata_object)
x_fullsamp = direct_reconstruction_fullsamp(kdata_object).data

direct_reconstruction_us = DirectReconstruction.from_kdata(kdata_object_us)
x_us = direct_reconstruction_us(kdata_object_us).data

fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata=kdata_object_us)
# dcf_operator = DensityCompensationOp(torch.sqrt(direct_reconstruction_us.dcf.data))  # BEFORE
dcf_operator = DensityCompensationOp(2 * torch.sqrt(direct_reconstruction_us.dcf.data))  # CORRECTED
acquisition_operator = dcf_operator @ fourier_operator

(kdata,) = acquisition_operator(x_fullsamp.data)
(xu,) = acquisition_operator.H(kdata)

fig, ax = plt.subplots(2, 3)
arrs = [xu, x_us, x_fullsamp]
errs = [arr - x_fullsamp for arr in arrs]
titles = ['retrospectively\n undersampled', 'direct recon from\n split data', 'ground-truth']
clim = [0, 100]
for k, (arr, err, title) in enumerate(zip(arrs, errs, titles)):
    ax[0, k].set_title(title, fontsize=7)
    ax[0, k].imshow(arr[0, 0, 0, ...].abs(), clim=clim)  # top row: magnitude images
    ax[1, k].imshow(3 * err[0, 0, 0, ...].abs(), clim=clim)  # bottom row: error vs. ground truth, amplified 3x
plt.show()

koflera commented 4 weeks ago

I have now also tried the following.

ellipses = generate_random_ellipses(8)
phantom = EllipsePhantom(ellipses)
kspace_data = phantom.kspace(shifted_kdatapuls.traj.ky, shifted_kdatapuls.traj.kx)

kdata_object = KData(data=kspace_data.unsqueeze(0), header=shifted_kdatapuls.header, traj=shifted_kdatapuls.traj)

kdata_object_us = KData.split_k1_into_other(kdata_object, torch.arange(0, 128, 16)[None, :], other_label='repetition')

direct_reconstruction_fullsamp = DirectReconstruction.from_kdata(kdata_object)
x_fullsamp = direct_reconstruction_fullsamp(kdata_object).data

direct_reconstruction_us = DirectReconstruction.from_kdata(kdata_object_us)
x_us = direct_reconstruction_us(kdata_object_us).data

fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata=kdata_object_us)
# dcf_operator = DensityCompensationOp(torch.sqrt(direct_reconstruction_us.dcf.data))  # BEFORE
dcf_operator_W12 = DensityCompensationOp(2 * torch.sqrt(direct_reconstruction_us.dcf.data))  # W^(1/2) = 2 * sqrt(dcf)
dcf_operator = DensityCompensationOp(4 * direct_reconstruction_us.dcf.data)  # W = (W^(1/2))^2 = 4 * dcf

(kdata,) = fourier_operator(x_fullsamp.data)
(xu_v1,) = fourier_operator.H(dcf_operator(kdata)[0])

acquisition_operator = dcf_operator_W12 @ fourier_operator
(kdata_dcf,) = acquisition_operator(x_fullsamp.data)
(xu_v2,) = acquisition_operator.H(kdata_dcf)

arrs = [xu_v1, xu_v2, x_us, x_fullsamp]
errs = [arr - x_fullsamp for arr in arrs]
titles = [
    'retrospectively\n undersampled with y=Ax, A:=F; \n x0 = A^H W y',
    'retrospectively\n undersampled with y=Ax, A:=W^(1/2)F; \n x0 = A^H W^(1/2) y',
    'direct recon from\n split data',
    'ground-truth',
]
single_fig_size = 5
figsize = (len(arrs) * single_fig_size / 2, single_fig_size)
fig, ax = plt.subplots(2, len(arrs), figsize=figsize)
clim = [0, 75]
for k, (arr, err, title) in enumerate(zip(arrs, errs, titles)):
    ax[0, k].set_title(title, fontsize=7)
    ax[0, k].imshow(arr[0, 0, 0, ...].abs(), clim=clim, cmap=plt.cm.Greys_r)
    ax[1, k].imshow(err[0, 0, 0, ...].abs(), clim=clim)

plt.setp(ax, xticks=[], yticks=[])
plt.show()

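[Figure: output. One column per title in the code; top row magnitude images, bottom row absolute error against the ground truth]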

This shows in more detail that the DCF returned by the direct reconstruction, call it $W'$, is apparently not itself the weighting $W = (W^{1/2})^H W^{1/2}$ that we want when setting up the problem $\min_x \frac{1}{2}\|W^{1/2}(Ax - y)\|_2^2$. Instead, the two seem to be related by $W^{1/2} = 2\,(W')^{1/2}$, i.e. $W = 4\,W'$.
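For reference, the link between this weighted least-squares problem and the DCF-weighted adjoint used in the code can be written out. The gradient of the objective at $x = 0$ is $-A^H W y$, so one gradient step of unit size from zero gives $x_0 = A^H W y$, and any scalar in front of $W$ (such as the empirically found factor 4) carries straight through to $x_0$:

```latex
\begin{aligned}
f(x) &= \tfrac{1}{2}\,\bigl\|W^{1/2}(Ax - y)\bigr\|_2^2, \qquad W = (W^{1/2})^H W^{1/2},\\
\nabla f(x) &= A^H W (Ax - y),\\
x_0 &= 0 - \nabla f(0) = A^H W y,\\
W &= 4\,W' \;\Longrightarrow\; x_0 = 4\,A^H W' y.
\end{aligned}
```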

It might be worth noting this in the documentation.

ckolbPTB commented 2 weeks ago

Could the problem be in this line

(kdata,) = fourier_operator(x_fullsamp.data)

and how this fourier_operator is defined?

For example, if I change nufft_oversampling from 2 (the default) to 1, the scaling factor between xu and x_us drops from 4 to 2. If I also set the recon_matrix to match the encoding_matrix along the readout direction, it drops further, but not yet to 1.

It seems to me that we need to simulate the acquisition in a way that avoids any additional NUFFT scaling caused merely by different matrix sizes.
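To check whether a change such as nufft_oversampling or recon_matrix brings the factor between xu and x_us closer to 1, the global factor can be estimated numerically instead of by eye. A minimal sketch; `estimate_scale` is a hypothetical helper, and the random tensors only stand in for the image tensors xu and x_us from the code above:

```python
import torch


def estimate_scale(x_ref: torch.Tensor, x: torch.Tensor) -> complex:
    """Least-squares scalar alpha minimising ||x_ref - alpha * x||_2.

    alpha = <x, x_ref> / <x, x>, using the complex inner product
    sum(conj(x) * x_ref) computed by torch.vdot on flattened tensors.
    """
    num = torch.vdot(x.flatten(), x_ref.flatten())
    den = torch.vdot(x.flatten(), x.flatten())
    return (num / den).item()


torch.manual_seed(0)
x_us = torch.randn(1, 1, 1, 64, 64, dtype=torch.complex64)  # stand-in image
xu = 4 * x_us  # toy case: retrospective image off by a global factor of 4

scale = estimate_scale(x_ref=xu, x=x_us)  # factor such that xu ≈ scale * x_us
print(abs(scale))
```

A value close to 1 would indicate that the simulated acquisition no longer introduces an extra global scaling.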