Image-X-Institute / mri_distortion_toolkit

Characterisation and reporting of geometric distortion in MRI
https://image-x-institute.github.io/mri_distortion_toolkit/
GNU General Public License v3.0
9 stars 3 forks source link

Slow finufft on linux #137

Open bwheelz36 opened 2 years ago

bwheelz36 commented 2 years ago

This script is to allow someone to reproduce the issue described here

set up environment and sample data

# Clone the toolkit repository and check out the distortion_correction branch
git clone https://github.com/ACRF-Image-X-Institute/MRI_DistortionQA.git
cd MRI_DistortionQA/
git checkout distortion_correction
# Create and activate a virtual environment named "venv", upgrade pip and
# setuptools inside it, then install the packages listed in dev_requirements.txt
python3 -m venv venv
source venv/bin/activate
pip3 install -U pip
pip3 install -U setuptools
pip install -r dev_requirements.txt

# replace default finufft with compiled nufft
pip uninstall finufft
git clone https://github.com/flatironinstitute/finufft.git
cd finufft/
# NOTE(review): presumably these targets build/test the library and then build
# the Python interface - confirm against the finufft installation docs
make test
make python
cd ..

# get sample data (300 Mb)
# the file is saved under the name "download", unpacked, and the archive removed
wget https://cloudstor.aarnet.edu.au/plus/s/Wm9vndV47u941JU/download
unzip download
rm download

Run example

copy the below into a new file at MRI_DistortionQA root:

from pathlib import Path
from MRI_DistortionQA.MarkerAnalysis import MarkerVolume
from MRI_DistortionQA.MarkerAnalysis import MatchedMarkerVolumes
from MRI_DistortionQA.FieldCalculation import ConvertMatchedMarkersToBz
from MRI_DistortionQA import calculate_harmonics
import numpy as np
from MRI_DistortionQA.K_SpaceCorrector import KspaceDistortionCorrector

# Locations of the sample data: ground-truth marker centroids and the
# distorted MR series.
gt_data_loc = Path(r'MRI_distortion_QA_sample_data/CT/slicer_centroids.mrk.json')
dis_data_loc = Path(r'MRI_distortion_QA_sample_data/MR/04 gre_trans_AP_330')

# Build one marker volume per data set.
gt_volume = MarkerVolume(gt_data_loc, r_max=300)
dis_volume = MarkerVolume(
    dis_data_loc,
    n_markers_expected=336,
    iterative_segmentation=True,
)

# Pair up ground-truth and distorted markers.
# NOTE(review): the keyword spelling 'n_refernce_markers' is kept exactly as
# written - confirm it matches the library signature.
matched_volume = MatchedMarkerVolumes(gt_volume, dis_volume, n_refernce_markers=11)

# Convert the matched markers into field data.
B_fields = ConvertMatchedMarkersToBz(matched_volume.MatchedCentroids, dis_volume.dicom_data)

# Normalise the three gradient harmonics by the gradient strengths read from
# the DICOM data (this normalises gradient harmonics to 1mT/m); the fourth
# entry is left at 1. Alternative without normalisation: [1, 1, 1, 1].
gradient_strength = np.array(dis_volume.dicom_data['gradient_strength'])
normalisation_factor = [
    1 / gradient_strength[0],
    1 / gradient_strength[1],
    1 / gradient_strength[2],
    1,
]
_all_harmonics = calculate_harmonics(B_fields.MagneticFields,
                                     n_order=8,
                                     norm=normalisation_factor)
G_x_Harmonics, G_y_Harmonics, G_z_Harmonics, B0_Harmonics = _all_harmonics

# Correct the input images with the k-space distortion corrector.
GDC = KspaceDistortionCorrector(
    ImageDirectory=dis_data_loc.resolve(),
    Gx_Harmonics=G_x_Harmonics.harmonics,
    Gy_Harmonics=G_y_Harmonics.harmonics,
    Gz_Harmonics=G_z_Harmonics.harmonics,
    ImExtension='dcm',
    dicom_data=dis_volume.dicom_data,
    correct_through_plane=False,
)
GDC.correct_all_images()
bwheelz36 commented 2 years ago

Working on a standalone script to demonstrate this issue. The following runs in 1.7 s on Linux, which is similar to the results I see for the real application (above) following the manual installation of fftw.

"""
Demonstrate the behavior of finufft discussed here:
https://github.com/flatironinstitute/finufft/issues/235
https://github.com/ACRF-Image-X-Institute/mri_distortion_toolkit/issues/137

This script demonstrates this behavior in a self contained way.

"""
from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter

def fiNufft_Ax(x):
    """
    Flatiron Institute nufft
    Returns A*x
    equivalent to the 'notranspose' option in shanshans code
    xj and yj are the nonuniform source points. they are essentially the encoding signals.
    sk and tk are uniform target points

    The input is cast to complex128 when it has any other dtype, passed to the
    module-level ``Nufft_Ax_Plan`` and the result is returned flattened.

    :param x: array handed to the plan
    :return: flattened output of ``Nufft_Ax_Plan.execute``
    """
    # Compare dtypes by value: identity ('is') between dtype objects is an
    # implementation detail and is not a reliable equality test.
    if x.dtype != np.complex128:
        x = x.astype(np.complex128)
    # plan-based replacement for the earlier direct call:
    #   nufft2d3(xj, yj, x, sk, tk, eps=1e-06, isign=-1)
    y = Nufft_Ax_Plan.execute(x, None)
    return y.flatten()

def fiNufft_Atb(x):
    """
    Flatiron Institute nufft, transpose direction: returns A'*x.

    Runs the module-level ``Nufft_Atb_Plan`` on ``x`` and returns the result
    flattened. Together with the forward function this lets the Nufft be
    wrapped as a scipy.sparse.linalg.LinearOperator usable by the lsqr
    algorithm; see here for explanation:
    https://stackoverflow.com/questions/48621407/python-equivalent-of-matlabs-lsqr-with-first-argument-a-function
    Equivalent to the 'transpose' option in shanshans code.

    :param x: array handed to the plan
    :return: flattened output of ``Nufft_Atb_Plan.execute``
    """
    return Nufft_Atb_Plan.execute(x, None).flatten()

# Random 148 x 148 test image: every value above 0.5 is replaced by 100.
StartingImage = np.random.rand(148, 148)
StartingImage[StartingImage > 0.5] = 100

# Shifted 2D FFT of the image, flattened to one value per pixel.
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = k_space.reshape(StartingImage.shape[0] * StartingImage.shape[1])

# Regular index grid: 148 unit-spaced samples per axis starting at -size / 2,
# expanded with 'ij' indexing and flattened.
x_lin_size, y_lin_size = (148, 148)
_x_axis = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
_y_axis = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
_x_grid, _y_grid = np.meshgrid(_x_axis, _y_axis, indexing='ij')
xn_lin = _x_grid.flatten()
yn_lin = _y_grid.flatten()

# The following is just a very hacky way to get some distorted indices
# which somewhat resemble the real case.
xj = (xn_lin + np.sin(xn_lin) * 5) * 10
yj = (yn_lin + np.sin(yn_lin) * 5) * 10
# Regular indices scaled by the grid size.
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size

# NOTE(review): the positional Plan arguments are presumably (nufft type,
# dimension, number of transforms, tolerance, isign) - confirm against the
# finufft.Plan documentation.
Nufft_Ax_Plan = Plan(3, 2, 1, 1e-06, -1)
Nufft_Ax_Plan.setpts(xj, yj, None, sk, tk)
# Second plan: the two point sets are swapped and the last argument has the
# opposite sign.
Nufft_Atb_Plan = Plan(3, 2, 1, 1e-06, 1)
Nufft_Atb_Plan.setpts(sk, tk, None, xj, yj)

# Square operator whose matvec / rmatvec are the two plan-backed functions.
A = LinearOperator((fk1.shape[0], fk1.shape[0]), matvec=fiNufft_Ax, rmatvec=fiNufft_Atb)
# lsqr is started from the flattened image cast to complex.
StartingImage = StartingImage.flatten().astype(complex)
maxit = 20
time = []
# Time 10 identical lsqr solves, each limited to maxit iterations.
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=maxit, x0=StartingImage)
    time.append(perf_counter() - _start_time)
# Report mean and standard deviation; a stray literal 'u' used to be printed
# in front of the plus-minus sign.
print(f'run time: {np.mean(time): 1.2f} \u00B1 {np.std(time): 1.2f}s')
bwheelz36 commented 2 years ago
| Case | Run time (s) |
| --- | --- |
| Linux: default fftw, pip installed finufft | 3.4 ± 0.33 |
| Linux: built fftw, pip installed finufft | 1.23 ± 0.16 |
| Linux: built fftw, built finufft | 1.44 ± 0.22 |
| Windows | 0.45 ± 0.02 |
bwheelz36 commented 2 years ago

example which does not use finufft:

"""
same example but without any finufft dependency

"""
# from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
from pynufft import NUFFT, helper

def PyNufft_Ax(x):
    """
    Call ``solve`` on the module-level ``Nufft_Ax_Plan`` with solver='cg' and
    maxiter=50 and return its result unchanged.

    NOTE(review): this function is registered as the forward (matvec) operator
    while calling the plan's ``solve`` method - confirm this is the intended
    pynufft call for A*x.

    :param x: array handed to the plan
    :return: whatever ``Nufft_Ax_Plan.solve`` returns
    """
    return Nufft_Ax_Plan.solve(x, solver='cg', maxiter=50)

def PyNufft_Atb(x):
    """
    Call ``adjoint`` on the module-level ``Nufft_Ax_Plan`` and return its
    result unchanged.

    :param x: array handed to the plan
    :return: whatever ``Nufft_Ax_Plan.adjoint`` returns
    """
    return Nufft_Ax_Plan.adjoint(x)

# set up random image: every value above 0.5 is replaced by 100
StartingImage = np.random.rand(148, 148)
Nrows = StartingImage.shape[0]
# columns come from the second axis (previously read from shape[0], which only
# gave the right answer because the image is square)
Ncolumns = StartingImage.shape[1]
StartingImage[StartingImage > 0.5] = 100
# shifted 2D FFT of the image, flattened to one value per pixel
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = np.reshape(k_space, StartingImage.shape[0] * StartingImage.shape[1])

# set up indices: 148 unit-spaced samples per axis starting at -size / 2,
# expanded with 'ij' indexing and flattened
x_lin_size, y_lin_size = (148, 148)
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
[xn_lin, yn_lin] = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# the following is just a very hacky way to get some distorted indices
# which somewhat resemble the real case
xj = (xn_lin + np.sin(xn_lin)*5)*10
yj = (yn_lin + np.sin(yn_lin)*5)*10
# regular indices scaled by the grid size
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size
# distorted indices divided by 2*pi
yn_dis = yj/(2*np.pi)
xn_dis = xj/(2*np.pi)

# instantiate plan
# alternative constructor call tried previously: NUFFT(helper.device_list()[0])
Nufft_Ax_Plan = NUFFT()

# distorted indices reshaped to the image grid and rescaled by 2*pi / size
Kx_dis_pytorch = np.reshape(xn_dis,[Nrows, Ncolumns]) / Nrows * 2 * np.pi  # [-pi, pi]
Ky_dis_pytorch = np.reshape(yn_dis,[Nrows, Ncolumns]) / Ncolumns * 2 * np.pi

# One (kx, ky) row per pixel in row-major order. This replaces an element by
# element double loop over rows and columns that filled the same array.
k_xy_dis = np.column_stack((Kx_dis_pytorch.ravel(), Ky_dis_pytorch.ravel()))

om = np.vstack([xj, yj])  # not passed to the plan below
Nd = StartingImage.shape  # image size
Kd = k_space.shape  # kspace size
Jd = (3, 3)  # interpolation size
Nufft_Ax_Plan.plan(k_xy_dis, Nd, Kd, Jd)
# square operator whose matvec / rmatvec are the two pynufft-backed functions
A = LinearOperator((fk1.shape[0], fk1.shape[0]), matvec=PyNufft_Ax, rmatvec=PyNufft_Atb)
time = []
# time 10 identical lsqr solves, each limited to 20 iterations
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=20, x0=None)
    time.append(perf_counter() - _start_time)
print(f'run time: {np.mean(time): 1.2f} \u00B1 {np.std(time): 1.2f}s')
bwheelz36 commented 2 years ago

linux_profile.txt

bwheelz36 commented 2 years ago

Windows: 2.18 ± 0.18 s; Linux: 4.23 ± 1.36 s

bwheelz36 commented 2 years ago
def dummy_Ax(x):
    """
    Stand-in operator: pause for 0.1 s, then return the element-wise square
    of ``x`` viewed as a 148 x 148 grid.

    :param x: array with 148 * 148 elements
    :return: 148 x 148 array of squared values
    """
    sleep(.1)
    grid = x.reshape(148, 148)
    return grid ** 2

def dummmy_Atb(x):
    """
    Stand-in operator: pause for 0.1 s, then return the element-wise square
    root of ``x`` viewed as a 148 x 148 grid.

    :param x: array with 148 * 148 elements
    :return: 148 x 148 array of square roots
    """
    sleep(.1)
    grid = x.reshape(148, 148)
    return np.sqrt(grid)