Open bwheelz36 opened 2 years ago
Working on a standalone script to demonstrate this issue. The following script runs in 1.7 s on Linux, which is similar to what I see for the real application (above) after manually installing FFTW.
"""
Demonstrate the behavior of finufft discussed here:
https://github.com/flatironinstitute/finufft/issues/235
https://github.com/ACRF-Image-X-Institute/mri_distortion_toolkit/issues/137
This script demonstrates this behavior in a self contained way.
"""
from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
def fiNufft_Ax(x):
    """Apply the forward Flatiron Institute NUFFT operator: return A*x.

    Equivalent to the 'notranspose' option in Shanshan's code.  Uses the
    module-level type-3 plan ``Nufft_Ax_Plan``, whose non-uniform source
    points (xj, yj) are essentially the encoding signals and whose target
    points are (sk, tk).  This is the planned equivalent of the one-shot call
    ``nufft2d3(xj, yj, x, sk, tk, eps=1e-06, isign=-1)``.

    Parameters
    ----------
    x : array_like
        Source strengths, one per source point.  Converted to complex128
        because the plan operates on double-precision complex data.

    Returns
    -------
    numpy.ndarray
        1-D array of transform values, one per target point.
    """
    # Compare dtypes by value, not identity: ``is`` on dtype objects is not a
    # reliable equality test.  asarray only copies when a conversion is needed.
    x = np.asarray(x, dtype=np.complex128)
    y = Nufft_Ax_Plan.execute(x, None)
    return y.ravel()
def fiNufft_Atb(x):
    """Apply the adjoint Flatiron Institute NUFFT operator: return A'*x.

    Together with ``fiNufft_Ax`` this defines the NUFFT as a
    ``scipy.sparse.linalg.LinearOperator`` usable by the lsqr algorithm; see
    https://stackoverflow.com/questions/48621407/python-equivalent-of-matlabs-lsqr-with-first-argument-a-function
    Equivalent to the 'transpose' option in Shanshan's code.  Uses the
    module-level plan ``Nufft_Atb_Plan``.

    Parameters
    ----------
    x : array_like
        Values at the forward operator's target points.  Converted to
        complex128, matching ``fiNufft_Ax``, because the plan operates on
        double-precision complex data.

    Returns
    -------
    numpy.ndarray
        1-D array of adjoint values, one per source point of the forward
        operator.
    """
    # Same conversion as fiNufft_Ax so lsqr can hand us real-valued vectors too.
    x = np.asarray(x, dtype=np.complex128)
    y = Nufft_Atb_Plan.execute(x, None)
    return y.ravel()
# Set up a random test image: uniform noise in [0, 1) with every pixel above
# 0.5 set to 100.
StartingImage = np.random.rand(148, 148)
StartingImage[StartingImage > 0.5] = 100
# Centred 2-D FFT of the image, flattened to the 1-D right-hand side for lsqr.
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = np.reshape(k_space, StartingImage.shape[0] * StartingImage.shape[1])

# Set up indices: an integer grid running from -size/2 to size/2 - 1 on each
# axis, flattened in row-major ('ij') order.
x_lin_size, y_lin_size = (148, 148)
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
xn_lin, yn_lin = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# The following is just a very hacky way to get some distorted indices which
# somewhat resemble the real case: a sinusoidal wobble on the grid, scaled by 10.
xj = (xn_lin + np.sin(xn_lin) * 5) * 10
yj = (yn_lin + np.sin(yn_lin) * 5) * 10
# Target points: the regular grid normalised by its size.
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size

# Forward operator A: 2-D type-3 NUFFT from source points (xj, yj) to target
# points (sk, tk), one transform, tolerance 1e-6, exponent sign -1.
Nufft_Ax_Plan = Plan(3, 2, 1, 1e-06, -1)
Nufft_Ax_Plan.setpts(xj, yj, None, sk, tk)
# Adjoint operator: the same point sets with source/target swapped and the
# opposite exponent sign.
Nufft_Atb_Plan = Plan(3, 2, 1, 1e-06, 1)
Nufft_Atb_Plan.setpts(sk, tk, None, xj, yj)

# Rows = target points (output of A), columns = source points (input of A).
# Passing dtype stops scipy from calling matvec once just to infer it.
A = LinearOperator(
    (sk.size, xj.size),
    matvec=fiNufft_Ax,
    rmatvec=fiNufft_Atb,
    dtype=np.complex128,
)

# lsqr is warm-started from the original image.
StartingImage = StartingImage.flatten().astype(complex)
maxit = 20
times = []
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=maxit, x0=StartingImage)
    times.append(perf_counter() - _start_time)
print(f'run time: {np.mean(times): 1.2f} \u00B1 {np.std(times): 1.2f}s')
Case | Run time (s) |
---|---|
Linux: default fftw, pip installed finufft | 3.4 ± 0.33 |
Linux: built fftw, pip installed finufft | 1.23 ± 0.16 |
Linux: built fftw, built finufft | 1.44 ± 0.22 |
Windows | 0.45 ± 0.02 |
Example that does not use finufft (it uses pynufft instead):
"""
same example but without any finufft dependency
"""
# from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
from pynufft import NUFFT, helper
def PyNufft_Ax(x):
    """Apply the pynufft forward operator: return A*x.

    ``scipy.sparse.linalg.lsqr`` expects ``matvec`` to apply the forward
    operator (image -> non-uniform k-space samples).  ``solve()`` must NOT be
    used here: it runs an iterative inverse reconstruction (k-space -> image,
    e.g. up to ``maxiter`` CG steps), which is not A and makes every matvec
    far more expensive.

    Parameters
    ----------
    x : array_like
        Flat image vector as passed by lsqr.  It is reshaped to the image shape
        the module-level ``Nufft_Ax_Plan`` was planned with.

    Returns
    -------
    numpy.ndarray
        1-D array of k-space samples, one per planned sample location.
    """
    image = np.reshape(x, Nufft_Ax_Plan.st['Nd'])
    y = Nufft_Ax_Plan.forward(image)
    return np.ravel(y)
def PyNufft_Atb(x):
    """Apply the pynufft adjoint operator: return A'*x.

    Parameters
    ----------
    x : array_like
        K-space sample values as passed by lsqr; flattened to 1-D before use.

    Returns
    -------
    numpy.ndarray
        The adjoint image from the module-level ``Nufft_Ax_Plan``, flattened
        to 1-D to match the vector shape lsqr works with.  The finufft
        callbacks in this file also return flat arrays.
    """
    y = Nufft_Ax_Plan.adjoint(np.ravel(x))
    return np.ravel(y)
# Set up a random test image: uniform noise in [0, 1) with every pixel above
# 0.5 set to 100.
StartingImage = np.random.rand(148, 148)
# Rows come from axis 0 and columns from axis 1.
Nrows, Ncolumns = StartingImage.shape
StartingImage[StartingImage > 0.5] = 100
# Centred 2-D FFT of the image, flattened to the 1-D right-hand side for lsqr.
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = np.reshape(k_space, StartingImage.shape[0] * StartingImage.shape[1])

# Set up indices: an integer grid running from -size/2 to size/2 - 1 on each
# axis, flattened in row-major ('ij') order.
x_lin_size, y_lin_size = (148, 148)
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
xn_lin, yn_lin = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# The following is just a very hacky way to get some distorted indices which
# somewhat resemble the real case: a sinusoidal wobble on the grid, scaled by 10.
xj = (xn_lin + np.sin(xn_lin) * 5) * 10
yj = (yn_lin + np.sin(yn_lin) * 5) * 10
yn_dis = yj / (2 * np.pi)
xn_dis = xj / (2 * np.pi)

# Instantiate the plan.
Nufft_Ax_Plan = NUFFT()
# Convert the distorted coordinates to radians.  The intent was [-pi, pi], but
# NOTE(review): xj reaches about +/-790 here, so these values reach about
# +/-5.3 rad.  Confirm pynufft treats out-of-range coordinates periodically.
Kx_dis_pytorch = np.reshape(xn_dis, [Nrows, Ncolumns]) / Nrows * 2 * np.pi
Ky_dis_pytorch = np.reshape(yn_dis, [Nrows, Ncolumns]) / Ncolumns * 2 * np.pi
# (M, 2) array of sample coordinates.  Row-major ravel gives the same
# (row, then column) order a nested i/j loop would.
k_xy_dis = np.column_stack([Kx_dis_pytorch.ravel(), Ky_dis_pytorch.ravel()])

Nd = StartingImage.shape  # image size
Kd = k_space.shape  # k-space grid size (same as Nd, i.e. no oversampling)
Jd = (3, 3)  # interpolation size
Nufft_Ax_Plan.plan(k_xy_dis, Nd, Kd, Jd)

# Rows = k-space samples (output of A), columns = image pixels (input of A).
# Passing dtype stops scipy from calling matvec once just to infer it.
A = LinearOperator(
    (k_xy_dis.shape[0], StartingImage.size),
    matvec=PyNufft_Ax,
    rmatvec=PyNufft_Atb,
    dtype=np.complex128,
)
times = []
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=20, x0=None)
    times.append(perf_counter() - _start_time)
print(f'run time: {np.mean(times): 1.2f} \u00B1 {np.std(times): 1.2f}s')
Windows: 2.18 ± 0.18 s; Linux: 4.23 ± 1.36 s
def dummy_Ax(x):
    """Stand-in forward operator that simulates a slow transform.

    Sleeps for 0.1 s, then returns ``x`` reshaped to (148, 148) and squared
    element-wise.

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        (148, 148) array of squared values.
    """
    # Local import: ``sleep`` is not imported at module level in this script,
    # so the bare name would raise NameError.
    from time import sleep

    sleep(.1)
    y = x.reshape(148, 148) ** 2
    return y
def dummmy_Atb(x):
    """Stand-in adjoint operator that simulates a slow transform.

    Sleeps for 0.1 s, then returns the element-wise square root of ``x``
    reshaped to (148, 148).

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        (148, 148) array of square roots.
    """
    # Local import: ``sleep`` is not imported at module level in this script,
    # so the bare name would raise NameError.
    from time import sleep

    sleep(.1)
    y = np.sqrt(x.reshape(148, 148))
    return y
This script allows someone to reproduce the issue described here.
Set up the environment and sample data.
Run the example.
Copy the code below into a new file at the MRI_DistortionQA root: