odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

FBP not scaled when PYFFTW_AVAILABLE is True #1617

Open Yao1993 opened 2 years ago

Yao1993 commented 2 years ago
import odl
import numpy as np
from odl.trafos import PYFFTW_AVAILABLE

apart = odl.uniform_partition(0, 2 * np.pi, 360)
dpart = odl.uniform_partition(-0.75, 0.75, 500)
geom = odl.tomo.geometry.conebeam.FanBeamGeometry(apart, dpart, src_radius=0.5, det_radius=0.5)
space = odl.uniform_discr(
    min_pt=[-0.25, -0.25], max_pt=[0.25, 0.25], shape=[512, 512],
    dtype='float32')

ray_trafo = odl.tomo.RayTransform(space, geom, impl="astra_cuda")
fbp = odl.tomo.fbp_op(ray_trafo, filter_type='Hann', frequency_scaling=0.8)

phantom = np.ones((512, 512), dtype="float32")
projection = ray_trafo(phantom)

recon = fbp(projection)
print(f"PYFFTW_AVAILABLE={PYFFTW_AVAILABLE}, Recon Mean={recon.asarray().mean()}")
recon.show();

When pyfftw is not installed, the output is: [image]
When pyfftw is installed, the output is: [image]


My environment is:
# Name                    Version                   Build  Channel
python                    3.9.12          h9a8a25e_1_cpython    conda-forge
pyfftw                    0.13.0           py39h51d1ae8_0    conda-forge
odl                       1.0.0.dev0               pypi_0    pypi # downloaded from github, installed by `python setup.py install`

Update: With pyfftw==0.12.0, the scale is correct.

[image]

ozanoktem commented 2 years ago

There seems to be a scaling issue in how FFT is used in the context of FBP.
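A quick way to see what such a scaling issue looks like is a forward/inverse round-trip check. The sketch below uses only NumPy. The helper names `roundtrip_scale` and `unnormalized_ifft` are made up for illustration and are not part of ODL or pyfftw, and the "faulty" inverse is simulated rather than taken from either library. It only shows that a missing 1/n normalization on the inverse transform shows up as a factor of n in the result.

```python
import numpy as np


def roundtrip_scale(forward, inverse, n=64, seed=0):
    """Return ||inverse(forward(x))|| / ||x|| for a random real signal x."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    y = inverse(forward(x))
    return float(np.linalg.norm(y) / np.linalg.norm(x))


def unnormalized_ifft(x):
    """Simulated faulty inverse (assumption): undo NumPy's 1/n factor."""
    return np.fft.ifft(x) * len(x)


# NumPy convention: unscaled forward transform, 1/n applied on the inverse,
# so the round trip reproduces the input.
numpy_scale = roundtrip_scale(np.fft.fft, np.fft.ifft)

# Without the 1/n factor the round trip is off by a factor of n.
unnormalized_scale = roundtrip_scale(np.fft.fft, unnormalized_ifft, n=64)

print(f"numpy round trip scale: {numpy_scale:.3f}")
print(f"unnormalized round trip scale: {unnormalized_scale:.3f}")
```

Passing a backend's own forward and inverse callables to the same helper would show whether that backend applies the normalization the calling code expects.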

ozanoktem commented 2 years ago

Just to make sure this is unrelated to ASTRA (which I think it is), could you, @Yao1993, run the same example with the scikit-image backend instead of ASTRA? For this, you would need to switch the geometry to parallel beam.

Yao1993 commented 2 years ago
import odl
import numpy as np
from odl.trafos import PYFFTW_AVAILABLE

apart = odl.uniform_partition(0, 2 * np.pi, 360)
dpart = odl.uniform_partition(-0.75, 0.75, 500)
geom = odl.tomo.Parallel2dGeometry(apart, dpart)
# geom = odl.tomo.geometry.conebeam.FanBeamGeometry(apart, dpart, src_radius=0.5, det_radius=0.5)
space = odl.uniform_discr(
    min_pt=[-0.25, -0.25], max_pt=[0.25, 0.25], shape=[512, 512],
    dtype='float32')

ray_trafo = odl.tomo.RayTransform(space, geom, impl="skimage")
fbp = odl.tomo.fbp_op(ray_trafo, filter_type='Hann', frequency_scaling=0.8)

phantom = np.ones((512, 512), dtype="float32")
projection = ray_trafo(phantom)

recon = fbp(projection)
print(f"PYFFTW_AVAILABLE={PYFFTW_AVAILABLE}, Recon Mean={recon.asarray().mean()}")
recon.show();

With pyfftw==0.12.0, the result is: [image]

With pyfftw==0.13.0, the result is: [image]

@ozanoktem

Yao1993 commented 2 years ago

The torch.fft module has nearly the same API as numpy.fft. I replaced numpy.fft with torch.fft in FourierTransform._call_numpy and found that the modified backend is even faster than the pyfftw backend. Could odl add a torch.fft backend for FourierTransform?
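To illustrate how close the two APIs are, here is a minimal sketch of a helper that uses torch.fft when PyTorch is importable and falls back to numpy.fft otherwise. The names `fftn_any` and `TORCH_AVAILABLE` are hypothetical and not part of ODL, and the sketch does not reproduce what `FourierTransform._call_numpy` does. It only shows that `np.fft.fftn(arr, axes=...)` and `torch.fft.fftn(tensor, dim=...)` can be swapped with little more than a keyword rename.

```python
import numpy as np

try:
    import torch
    import torch.fft  # explicit import, needed on older PyTorch releases
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


def fftn_any(arr, axes=None):
    """Forward n-dimensional FFT via torch.fft if available, else numpy.fft.

    Hypothetical helper for illustration only; not ODL code.
    """
    arr = np.ascontiguousarray(arr)
    if TORCH_AVAILABLE:
        # torch.fft.fftn takes `dim` where numpy.fft.fftn takes `axes`.
        out = torch.fft.fftn(torch.from_numpy(arr), dim=axes)
        return out.numpy()
    return np.fft.fftn(arr, axes=axes)


x = np.random.default_rng(0).standard_normal((8, 16))
result = fftn_any(x, axes=(0, 1))
reference = np.fft.fftn(x, axes=(0, 1))
print("matches numpy.fft:", np.allclose(result, reference))
```

Both functions use the same default normalization (unscaled forward transform), so the outputs should agree up to floating-point error. Any speed difference depends on the array size, the device, and the PyTorch build.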