GalSim-developers / GalSim

The modular galaxy image simulation toolkit. Documentation:
http://galsim-developers.github.io/GalSim/

BUG in FFTs? #1248

Closed beckermr closed 11 months ago

beckermr commented 11 months ago

I very much hope I am wrong about this one, but I cannot figure out what is going on, so here we go.

Here is the test case:

import pytest


def _compute_fft_with_numpy_galsim(im):
    import numpy as np
    from galsim import BoundsI, Image
    No2 = max(-im.bounds.xmin, im.bounds.xmax + 1, -im.bounds.ymin, im.bounds.ymax + 1)

    full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1)
    if im.bounds == full_bounds:
        # Then the image is already in the shape we need.
        ximage = im
    else:
        # Then we pad out with zeros
        ximage = Image(full_bounds, dtype=im.dtype, init_value=0)
        ximage[im.bounds] = im[im.bounds]

    dx = im.scale
    # dk = 2pi / (N dx), with N = 2 * No2
    dk = np.pi / (No2 * dx)

    out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk)
    out._array = np.fft.fftshift(np.fft.rfft2(ximage.array), axes=0)
    out *= dx * dx
    out.setOrigin(0, -No2)
    return out

@pytest.mark.parametrize("n", [5, 4])
def test_galsim_fft_vs_numpy(n):
    import numpy as np
    import galsim

    rng = np.random.RandomState(42)
    arr = rng.normal(size=(n, n))
    im = galsim.Image(arr, scale=1)
    kim = im.calculate_fft()
    xkim = kim.calculate_inverse_fft()

    np.testing.assert_allclose(im.array, xkim[im.bounds].array)

    np_kim = _compute_fft_with_numpy_galsim(im)
    print("ratio real:\n", np_kim.array.real / kim.array.real)
    print("ratio imag:\n", np_kim.array.imag / kim.array.imag)
    np.testing.assert_allclose(kim.array, np_kim.array)
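The array shapes in the output follow from the padding arithmetic in the helper. Below is a minimal sketch of that arithmetic. It assumes the default (1, 1) origin that galsim.Image gives an array-backed image, so an n × n array has bounds 1..n on each axis. padded_fft_layout is a hypothetical name used only for illustration:

```python
import numpy as np


def padded_fft_layout(nx, ny, xmin=1, ymin=1, dx=1.0):
    """Padding arithmetic from _compute_fft_with_numpy_galsim for an nx x ny image.

    Hypothetical helper; xmin = ymin = 1 is the default origin of galsim.Image(arr).
    Returns (No2, N, k-image shape, dk).
    """
    xmax, ymax = xmin + nx - 1, ymin + ny - 1
    No2 = max(-xmin, xmax + 1, -ymin, ymax + 1)
    N = 2 * No2                  # padded grid spans [-No2, No2 - 1] on each axis
    kshape = (N, No2 + 1)        # rows: ky = -No2..No2-1; columns: kx = 0..No2
    dk = 2 * np.pi / (N * dx)    # same as np.pi / (No2 * dx)
    return No2, N, kshape, dk


layout_5 = padded_fft_layout(5, 5)
layout_4 = padded_fft_layout(4, 4)
```

For n = 5 this gives No2 = 6, a 12 × 12 padded grid and a 12 × 7 k-image. For n = 4 it gives No2 = 5 and a 10 × 6 k-image. These match the shapes of the ratio arrays printed in the output.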

And the output:

% pytest -vvs -k test_galsim_fft_vs_numpy
================================================================================= test session starts ==================================================================================
platform darwin -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0 -- /Users/beckermr/mambaforge/envs/work/bin/python3.10
cachedir: .pytest_cache
rootdir: /Users/beckermr/Desktop/JAX-GalSim
configfile: pytest.ini
testpaths: tests/GalSim/tests/, tests/jax
collected 241 items / 239 deselected / 2 selected                                                                                                                                      

tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[5] ratio real:
 [[ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]]
ratio imag:
 [[inf -1.  1. -1.  1. -1. inf]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [nan -1.  1. -1.  1. -1. nan]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]
 [ 1. -1.  1. -1.  1. -1.  1.]
 [-1.  1. -1.  1. -1.  1. -1.]]
FAILED
tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[4] ratio real:
 [[-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]]
ratio imag:
 [[inf  1. -1.  1. -1. inf]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [nan -1.  1. -1.  1. nan]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.]]
FAILED

======================================================================================= FAILURES =======================================================================================
_____________________________________________________________________________ test_galsim_fft_vs_numpy[5] ______________________________________________________________________________

n = 5

    @pytest.mark.parametrize("n", [5, 4])
    def test_galsim_fft_vs_numpy(n):
        import numpy as np
        import galsim

        rng = np.random.RandomState(42)
        arr = rng.normal(size=(n, n))
        im = galsim.Image(arr, scale=1)
        kim = im.calculate_fft()
        xkim = kim.calculate_inverse_fft()

        np.testing.assert_allclose(im.array, xkim[im.bounds].array)

        np_kim = _compute_fft_with_numpy_galsim(im)
        print("ratio real:\n", np_kim.array.real / kim.array.real)
        print("ratio imag:\n", np_kim.array.imag / kim.array.imag)
>       np.testing.assert_allclose(kim.array, np_kim.array)

tests/jax/test_interpolatedimage_utils.py:253: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x11f03c8b0>, array([[ 1.29649577+0.j        , -2.92335712-1.38166268j,...5486278e+00j,
         3.20398239-2.17843793e+00j,  0.36882966+9.13552566e-01j,
         2.54946424-1.31877006e+00j]]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 42 / 84 (50%)
E           Max absolute difference: 17.07566526
E           Max relative difference: 2.
E            x: array([[ 1.296496+0.j      , -2.923357-1.381663j, -1.372159+2.572466j,
E                   -0.958652+0.786162j, -0.461493+4.2329j  ,  3.882009+1.794316j,
E                    2.370809+0.j      ],...
E            y: array([[ 1.296496+4.440892e-16j,  2.923357+1.381663e+00j,
E                   -1.372159+2.572466e+00j,  0.958652-7.861625e-01j,
E                   -0.461493+4.232900e+00j, -3.882009-1.794316e+00j,...

../../mambaforge/envs/work/lib/python3.10/contextlib.py:79: AssertionError
_____________________________________________________________________________ test_galsim_fft_vs_numpy[4] ______________________________________________________________________________

n = 4

    @pytest.mark.parametrize("n", [5, 4])
    def test_galsim_fft_vs_numpy(n):
        import numpy as np
        import galsim

        rng = np.random.RandomState(42)
        arr = rng.normal(size=(n, n))
        im = galsim.Image(arr, scale=1)
        kim = im.calculate_fft()
        xkim = kim.calculate_inverse_fft()

        np.testing.assert_allclose(im.array, xkim[im.bounds].array)

        np_kim = _compute_fft_with_numpy_galsim(im)
        print("ratio real:\n", np_kim.array.real / kim.array.real)
        print("ratio imag:\n", np_kim.array.imag / kim.array.imag)
>       np.testing.assert_allclose(kim.array, np_kim.array)

tests/jax/test_interpolatedimage_utils.py:253: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x13e958b80>, array([[-3.75327258+0.j        , -0.01286824+3.25295301j,...2876+8.12743864e-01j, -0.3500387 -5.01804082e+00j,
         4.42163476+1.90961627e-01j, -0.66975573+2.45705883e+00j]]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 30 / 60 (50%)
E           Max absolute difference: 14.13049298
E           Max relative difference: 2.
E            x: array([[-3.753273+0.j      , -0.012868+3.252953j,  2.062001+0.513939j,
E                    1.540098-0.864889j, -0.185365-2.602459j, -3.054459-0.j      ],
E                  [ 1.472279+1.212782j,  1.221602-1.327474j, -1.618268-1.196882j,...
E            y: array([[ 3.753273+2.220446e-16j, -0.012868+3.252953e+00j,
E                   -2.062001-5.139386e-01j,  1.540098-8.648888e-01j,
E                    0.185365+2.602459e+00j, -3.054459-5.551115e-17j],...

../../mambaforge/envs/work/lib/python3.10/contextlib.py:79: AssertionError
=================================================================================== warnings summary ===================================================================================
tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[5]
tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[4]
  /Users/beckermr/Desktop/JAX-GalSim/tests/jax/test_interpolatedimage_utils.py:252: RuntimeWarning: divide by zero encountered in divide
    print("ratio imag:\n", np_kim.array.imag / kim.array.imag)

tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[5]
tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[4]
  /Users/beckermr/Desktop/JAX-GalSim/tests/jax/test_interpolatedimage_utils.py:252: RuntimeWarning: invalid value encountered in divide
    print("ratio imag:\n", np_kim.array.imag / kim.array.imag)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================== short test summary info ================================================================================
FAILED tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[5] - AssertionError: 
FAILED tests/jax/test_interpolatedimage_utils.py::test_galsim_fft_vs_numpy[4] - AssertionError: 
==================================================================== 2 failed, 239 deselected, 4 warnings in 1.54s =====================================================================

As you can see, the checkerboard pattern of ±1 ratios, which is supposed to be removed, remains.
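The origin of the ±1 pattern can be sketched with numpy alone. This assumes an even padded size N with the x-space origin at index N/2, as in the zero-padded ximage. Shifting the input by N/2 pixels along each axis multiplies mode (p, q) by exp(-iπ(p + q)) = (-1)^(p + q). The random 12 × 12 array below is a stand-in for the padded n = 5 image, and checkerboard_sign is a hypothetical helper:

```python
import numpy as np


def checkerboard_sign(N):
    """(-1)**(p + q) on the (N, N // 2 + 1) rfft2 output grid of an N x N array."""
    p = np.arange(N)[:, None]            # ky index (rows), before any output shift
    q = np.arange(N // 2 + 1)[None, :]   # kx index (columns): rfft2 keeps 0..N/2
    return np.where((p + q) % 2 == 0, 1.0, -1.0)


N = 12                                   # padded size for the n = 5 case (No2 = 6)
rng = np.random.RandomState(42)
full = rng.normal(size=(N, N))           # stand-in for the zero-padded ximage.array

without_shift = np.fft.rfft2(full)                # what the helper in the issue computes
with_shift = np.fft.rfft2(np.fft.fftshift(full))  # origin moved to index 0 first

# Shift theorem: moving the origin by N/2 pixels per axis multiplies each mode by (-1)**(p + q).
assert np.allclose(with_shift, checkerboard_sign(N) * without_shift)
```

The output is then fftshifted along axis 0, so row 0 of the printed ratio holds ky = -No2. The pattern therefore starts with (-1)^No2: +1 for n = 5 (No2 = 6) and -1 for n = 4 (No2 = 5), as in the printouts.

The inf and nan entries in the imaginary ratio are not extra mismatches. Where kx and ky are each 0 or N/2, the transform of a real array is real. The imaginary parts there are zero up to roundoff, so the ratio comes out as inf or nan.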

beckermr commented 11 months ago

cc @rmjarvis

rmjarvis commented 11 months ago

Before I go too far trying to understand your test, have you looked at this unit test that compares our FFT functions with numpy's?

https://github.com/GalSim-developers/GalSim/blob/releases/2.4/tests/test_draw.py#L1342

beckermr commented 11 months ago

I have not. This is definitely due to the shifting of the center of the spectrum, so I suspect it is a convention issue that won't affect drawing itself.

rmjarvis commented 11 months ago

We do have the convention that the FFT image is centered, i.e. (kx, ky) = (0, 0) sits at the center of the image rather than at the corner as in numpy's output. That's probably the source of the disagreement. But we have functions like galsim.rfft2 that are meant to be drop-in replacements for np.fft.rfft2, but using FFTW for the back end. (For many purposes, FFTW is a bit faster, but YMMV.) Anyway, looking at those routines might be helpful for teasing out the conventions.
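The centered convention described here can be sketched with numpy alone. The sketch assumes an even N × N grid whose x-space origin sits at index N // 2, which is the layout of the zero-padded image in the test case. centered_rfft2 and centered_irfft2 are hypothetical helper names for illustration, not GalSim functions, and this is not how GalSim implements it internally:

```python
import numpy as np


def centered_rfft2(x):
    """rfft2 of an N x N array (N even) whose origin x = y = 0 sits at index N // 2.

    Output row N // 2 holds ky = 0; column 0 holds kx = 0 (rfft2 keeps kx >= 0 only).
    """
    return np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(x)), axes=0)


def centered_irfft2(k):
    """Inverse of centered_rfft2 for an (N, N // 2 + 1) input."""
    N = k.shape[0]
    corner = np.fft.irfft2(np.fft.ifftshift(k, axes=0), s=(N, N))
    return np.fft.ifftshift(corner)      # move the origin back to index N // 2


N = 10
delta = np.zeros((N, N))
delta[N // 2, N // 2] = 1.0                              # unit point source at the origin
assert np.allclose(centered_rfft2(delta), 1.0)           # constant, no phase
assert not np.allclose(np.fft.rfft2(delta), 1.0)         # unshifted: +/-1 checkerboard

x = np.random.RandomState(0).normal(size=(N, N))
assert np.allclose(centered_irfft2(centered_rfft2(x)), x)  # round trip
```

A unit point source at the origin transforms to a constant 1 with no phase, which is a quick check that the origin is in the right place. Leaving out the input shift turns that constant into the ±1 checkerboard from the issue.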

beckermr commented 11 months ago

Ahhhh, thanks for the pointers! The test suite indeed had the right answer in it; the difference comes down to the shifts. For the record, it is np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) that is needed above, i.e. the input array has to be fftshifted too!
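The corrected expression can be cross-checked against a brute-force sum over true pixel coordinates, F(kx, ky) = dx² Σ f(x, y) exp(-i(kx x + ky y)), evaluated on the same k grid. This sketch uses numpy only and does not call GalSim, so it checks the recipe against the definition rather than against calculate_fft itself. galsim_style_kimage and direct_dft are hypothetical names, and the pixel layout assumes the (1, 1) default origin of galsim.Image(arr):

```python
import numpy as np


def galsim_style_kimage(arr, xmin=1, ymin=1, dx=1.0):
    """numpy version of the corrected recipe, without GalSim objects.

    arr[j, i] is the pixel at integer coordinates (xmin + i, ymin + j).
    Returns (kimage, No2).
    """
    ny, nx = arr.shape
    xmax, ymax = xmin + nx - 1, ymin + ny - 1
    No2 = max(-xmin, xmax + 1, -ymin, ymax + 1)
    N = 2 * No2
    # Zero-pad onto the grid [-No2, No2 - 1]^2, with x = y = 0 at index No2.
    full = np.zeros((N, N))
    full[ymin + No2:ymax + No2 + 1, xmin + No2:xmax + No2 + 1] = arr
    # The fix from this thread: shift the input as well as the output.
    kimage = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(full)), axes=0)
    return kimage * dx * dx, No2


def direct_dft(arr, xmin, ymin, No2, dx=1.0):
    """Brute-force dx^2 * sum f(x, y) exp(-i (kx x + ky y)) on the same k grid."""
    ny, nx = arr.shape
    x = (xmin + np.arange(nx)) * dx
    y = (ymin + np.arange(ny)) * dx
    dk = np.pi / (No2 * dx)
    kx = np.arange(No2 + 1) * dk         # columns: kx = 0 .. No2 dk
    ky = np.arange(-No2, No2) * dk       # rows: ky = -No2 dk .. (No2 - 1) dk
    return dx * dx * np.exp(-1j * np.outer(ky, y)) @ arr @ np.exp(-1j * np.outer(x, kx))


rng = np.random.RandomState(42)
arr = rng.normal(size=(5, 5))
kimage, No2 = galsim_style_kimage(arr)
assert kimage.shape == (12, 7)
assert np.allclose(kimage, direct_dft(arr, 1, 1, No2))
```

The agreement should also hold for other bounds and pixel scales, e.g. xmin = ymin = -2 with dx = 0.1. The input fftshift is exactly what moves x = y = 0 to array index 0 before the transform.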