GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

ENH add InterpolatedImage #60

Closed: beckermr closed this 10 months ago

beckermr commented 12 months ago

This PR adds interpolated images.

Closes #54

beckermr commented 11 months ago

This PR is blocked on figuring out the differences in this galsim issue: https://github.com/GalSim-developers/GalSim/issues/1248

beckermr commented 10 months ago

metacal in jax_galsim matches galsim to an absolute accuracy of 5e-5:

[Image: mcal_jax_galsim]

beckermr commented 10 months ago

Unfortunately I cannot jit it!

beckermr commented 10 months ago

The workspace stuff is not working

(work) beckermr@finnegan JAX-GalSim % pytest -vvx -k test_conserve_dc
========================================================================== test session starts ===========================================================================
platform darwin -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0 -- /Users/beckermr/mambaforge/envs/work/bin/python3.10
cachedir: .pytest_cache
rootdir: /Users/beckermr/Desktop/JAX-GalSim
configfile: pyproject.toml
testpaths: tests/GalSim/tests/, tests/jax, tests/Coord/tests/
collected 1547 items / 1546 deselected / 1 selected                                                                                                                      

tests/GalSim/tests/test_interpolatedimage.py::test_conserve_dc FAILED                                                                                              [100%]

================================================================================ FAILURES ================================================================================
____________________________________________________________________________ test_conserve_dc ____________________________________________________________________________
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

    @timer
    def test_conserve_dc():
        """Test that the conserve_dc option for Lanczos does so.
        Note: the idea of conserving flux is a bit of a misnomer.  No interpolant does so
        precisely in general.  What we are really testing is that a flat background input
        image has a relatively flat output image.
        """
        im1_size = 40
        scale1 = 0.23
        init_val = 1.

        im2_size = 100
        scale2 = 0.011

        im1 = galsim.ImageF(im1_size, im1_size, scale=scale1, init_value=init_val)

        # im2 has a much smaller scale, but the same size, so interpolating an "infinite"
        # constant field.
        im2 = galsim.ImageF(im2_size, im2_size, scale=scale2)

        for interp in ['linear', 'cubic', 'quintic']:
            print('Testing interpolant ',interp)
            obj = galsim.InterpolatedImage(im1, x_interpolant=interp, normalization='sb')
            obj.drawImage(im2, method='sb')
            print('The maximum error is ',np.max(abs(im2.array-init_val)))
            np.testing.assert_array_almost_equal(
                    im2.array,init_val,5,
                    '%s did not preserve a flat input flux using xvals.'%interp)

            # Convolve with a delta function to force FFT drawing.
            delta = galsim.Gaussian(sigma=1.e-8)
            obj2 = galsim.Convolve([obj,delta])
            obj2.drawImage(im2, method='sb')
            print('The maximum error is ',np.max(abs(im2.array-init_val)))
            np.testing.assert_array_almost_equal(
                    im2.array,init_val,5,
                    '%s did not preserve a flat input flux using uvals.'%interp)

            check_pickle(obj, lambda x: x.drawImage(method='no_pixel'))
            check_pickle(obj2, lambda x: x.drawImage(method='no_pixel'))
            check_pickle(obj)
            check_pickle(obj2)

        for n in [3,4,5,6,7,8]:  # n=8 tests the generic formulae, since not specialized.
            print('Testing Lanczos interpolant with n = ',n)
            lan = galsim.Lanczos(n, conserve_dc=True)
            obj = galsim.InterpolatedImage(im1, x_interpolant=lan, normalization='sb')
            obj.drawImage(im2, method='sb')
            print('The maximum error is ',np.max(abs(im2.array-init_val)))
            np.testing.assert_array_almost_equal(
                    im2.array,init_val,5,
                    'Lanczos %d did not preserve a flat input flux using xvals.'%n)

            # Convolve with a delta function to force FFT drawing.
            delta = galsim.Gaussian(sigma=1.e-8)
            obj2 = galsim.Convolve([obj,delta])
            obj2.drawImage(im2, method='sb')
            print('The maximum error is ',np.max(abs(im2.array-init_val)))
            np.testing.assert_array_almost_equal(
                    im2.array,init_val,5,
                    'Lanczos %d did not preserve a flat input flux using uvals.'%n)

>           check_pickle(obj, lambda x: x.drawImage(method='no_pixel'))

tests/GalSim/tests/test_interpolatedimage.py:1323: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/conftest.py:37: in _check_pickle
    return orig_check_pickle(*args, **kwargs)
../../mambaforge/envs/work/lib/python3.10/site-packages/galsim/utilities.py:1590: in check_pickle
    f1 = func(obj)
tests/GalSim/tests/test_interpolatedimage.py:1323: in <lambda>
    check_pickle(obj, lambda x: x.drawImage(method='no_pixel'))
jax_galsim/gsobject.py:713: in drawImage
    added_photons, image = prof.drawReal(image, add_to_image)
jax_galsim/gsobject.py:735: in drawReal
    im1 = self._drawReal(image)
jax_galsim/transform.py:367: in _drawReal
    return self._original._drawReal(image, jac, (dx, dy), flux_scaling)
jax_galsim/transform.py:367: in _drawReal
    return self._original._drawReal(image, jac, (dx, dy), flux_scaling)
jax_galsim/interpolatedimage.py:838: in _drawReal
    return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling)
jax_galsim/core/draw.py:27: in draw_by_xValue
    im = jax.vmap(lambda *args: gsobject._xValue(PositionD(*args)))(
jax_galsim/core/draw.py:27: in <lambda>
    im = jax.vmap(lambda *args: gsobject._xValue(PositionD(*args)))(
jax_galsim/interpolatedimage.py:792: in _xValue
    vals = _draw_with_interpolant_xval(
jax_galsim/interpolatedimage.py:911: in _draw_with_interpolant_xval
    z = jax.lax.fori_loop(
jax_galsim/interpolatedimage.py:905: in _body
    wy = interp._xval_noraise(_y)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = galsim.Lanczos(3, True, gsparams=galsim.GSParams(128,8192,0.005,5.0,0.001,1e-05,1e-05,1,0.0001,1e-06,1e-06,1e-08,1e-05))
x = Traced<ShapedArray(float64[44])>with<DynamicJaxprTrace(level=3/0)>

    def _xval_noraise(self, x):
>       return Lanczos._xval(x, self._n, self._conserve_dc, self._K_arr)
E       jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[6] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
E       JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
E       The function being traced when the value leaked was scanned_fun at /Users/beckermr/mambaforge/envs/work/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1836 traced for scan.
E       ------------------------------
E       The leaked intermediate value was created on line /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolant.py:1347 (_K_arr). 
E       ------------------------------
E       When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
E       ------------------------------
E       /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolatedimage.py:792 (_xValue)
E       /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolatedimage.py:911 (_draw_with_interpolant_xval)
E       /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolatedimage.py:905 (_body)
E       /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolant.py:1490 (_xval_noraise)
E       /Users/beckermr/Desktop/JAX-GalSim/jax_galsim/interpolant.py:1347 (_K_arr)
E       ------------------------------
E       
E       To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
E       See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

jax_galsim/interpolant.py:1490: UnexpectedTracerError
-------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------
Testing interpolant  linear
The maximum error is  0.0
The maximum error is  5.841255e-06
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Testing interpolant  cubic
The maximum error is  0.0
The maximum error is  5.9604645e-06
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Testing interpolant  quintic
The maximum error is  0.0
The maximum error is  5.9604645e-06
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
Try pickling  galsim.Convolve(galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64), galsim.Gaussian(sigma=1e-08, flux=1.0))
Testing Lanczos interpolant with n =  3
The maximum error is  7.2717667e-06
The maximum error is  1.2695789e-05
Try pickling  galsim.InterpolatedImage(image=galsim.Image(bounds=galsim.BoundsI(-20,19,-20,19), scale=0.23, dtype=numpy.float64), flux=84.64)
======================================================================== short test summary info =========================================================================
FAILED tests/GalSim/tests/test_interpolatedimage.py::test_conserve_dc - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate valu...
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
================================================================== 1 failed, 1546 deselected in 11.21s ===================================================================

cc @jecampagne @EiffL for visibility and advice
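The traceback points at `_K_arr`, a value that appears to be memoized on the `Lanczos` object. The general failure mode can be reproduced in a few lines with a hypothetical `LeakyInterp` class (not the actual jax_galsim code): an array first computed inside `jax.jit` is stored on `self`, so the cache ends up holding a tracer, and reusing it after that trace has finished raises `UnexpectedTracerError`.

```python
import jax
import jax.numpy as jnp


class LeakyInterp:
    """Hypothetical interpolant that memoizes a coefficient array on itself."""

    def __init__(self, n):
        self._n = n
        self._coeff_cache = None

    def coeffs(self):
        if self._coeff_cache is None:
            # If this first runs under jit/vmap/scan, the stored value is a
            # tracer that outlives the trace that created it.
            self._coeff_cache = jnp.arange(self._n, dtype=jnp.float32) ** 2
        return self._coeff_cache


leaky = LeakyInterp(6)
# The first call happens inside jit, so the cache is filled with a tracer.
jax.jit(lambda x: x * leaky.coeffs().sum())(jnp.ones(3))

try:
    # Reusing the cached value outside the trace that created it fails.
    leaky.coeffs() + 1.0
    leaked = False
except jax.errors.UnexpectedTracerError:
    leaked = True

print("tracer leaked:", leaked)
```

Setting `JAX_CHECK_TRACER_LEAKS`, as the error message suggests, makes JAX report this kind of leak at the point where the tracer escapes rather than at its later reuse.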

beckermr commented 10 months ago

So I removed the caching here and the test time has gone up to ~20-30 minutes. TBH, jax doesn't support the cache mechanisms in use in jax-cosmo and the workarounds are not so wonderful. So my preference is to leave things as is.

beckermr commented 10 months ago

Once we've applied JIT, metacal runs in ~40ms on my CPU compared to 5ms for the galsim code. I have not tested what happens on a GPU. This is not great but not too bad IMHO for a first pass.

The main issue right now is that it takes 10s-15s to JIT the combination of operations. :/
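Because JAX dispatches work asynchronously, numbers like these are easiest to compare when compilation and execution are timed separately. A sketch using ahead-of-time lowering (`jax.jit(...).lower(...).compile()`) and `block_until_ready()`, with a hypothetical `stand_in_pipeline` in place of the actual metacal code:

```python
import time

import jax
import jax.numpy as jnp


def stand_in_pipeline(img):
    """Hypothetical stand-in for an expensive chain of image operations."""
    k = jnp.fft.fft2(img)
    return jnp.fft.ifft2(k * jnp.conj(k)).real.sum()


x = jnp.ones((64, 64))
jitted = jax.jit(stand_in_pipeline)

t0 = time.perf_counter()
compiled = jitted.lower(x).compile()  # tracing + XLA compilation only
compile_s = time.perf_counter() - t0

compiled(x).block_until_ready()  # warm-up run
t0 = time.perf_counter()
for _ in range(10):
    # block_until_ready waits for the asynchronous result before timing stops
    out = compiled(x).block_until_ready()
run_s = (time.perf_counter() - t0) / 10

print(f"compile: {compile_s:.3f} s, run: {1e3 * run_s:.3f} ms")
```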

beckermr commented 10 months ago

Here's another look at the errors. The top row is the PSF interpolated and then redrawn with a pixel. The middle row is the same profile but drawn without a pixel (pure real-space drawing). The bottom row is the metacal result. The residual plots on the right are all in units of 1e-5.

[Screenshot 2023-10-26 at 4:27:42 PM]

Very clearly something is going wrong in the Fourier-space versions. Either the problem is in the InterpolatedImage class itself (though I do test the kValues against galsim directly), or we don't have k-space wrapping working for images yet.
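Sampling a profile on a real-space pixel grid makes its Fourier transform periodic, so the transform of the sampled image is a sum of shifted copies of the continuous one. k-space wrapping reproduces this by folding samples that lie beyond the target k grid back onto it, summing all values whose indices agree modulo the grid size. A minimal JAX sketch of that fold is below, assuming a square array in `numpy.fft` ordering (k = 0 at index 0) whose size is an integer multiple of the target size. The `wrap_periodic` helper is hypothetical and not taken from PR #65 or galsim; a real implementation also has to handle centered bounds and Hermitian half-plane storage.

```python
import jax.numpy as jnp


def wrap_periodic(kimage, n):
    """Fold a (m*n, m*n) array onto (n, n) by summing its periodic copies.

    Sketch only: assumes a square input whose size is an integer multiple m
    of n, with the k = 0 sample stored at index 0 (numpy.fft layout).
    """
    big = kimage.shape[0]
    if kimage.shape != (big, big) or big % n != 0:
        raise ValueError("expected a square array whose size is a multiple of n")
    m = big // n
    # Entry (j, i) collects every sample whose indices are congruent to
    # (j, i) modulo n, i.e. whose frequencies differ by multiples of n * dk.
    return kimage.reshape(m, n, m, n).sum(axis=(0, 2))


print(wrap_periodic(jnp.ones((6, 6)), 3))
```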

jecampagne commented 10 months ago

Hi @beckermr, regarding the metacal result above (although I am not an expert), maybe a test of the wrapping effect would help: 1) If you have the galsim code in hand, is it possible to get a version without the k-space wrapping? 2) Another option is to convolve the image with a Gaussian that cuts the power at high k before the metacal interpolation, and see if the result changes. Sorry if this turns out to be silly reasoning.
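Suggestion 2 amounts to multiplying the image's Fourier transform by a Gaussian, exp(-k²σ²/2), which leaves the total flux (the k = 0 mode) unchanged while suppressing high-k power. A sketch with a hypothetical `gaussian_lowpass` helper, assuming square pixels of side `scale` and `sigma` given in the same units:

```python
import jax.numpy as jnp


def gaussian_lowpass(image, sigma, scale):
    """Convolve image with a unit-flux Gaussian of width sigma.

    Implemented by multiplying the FFT by exp(-k^2 sigma^2 / 2). The
    convolution is circular, so flux near one edge wraps to the other.
    """
    ny, nx = image.shape
    # Angular frequencies (radians per unit of scale) for each FFT index.
    kx = 2.0 * jnp.pi * jnp.fft.fftfreq(nx, d=scale)
    ky = 2.0 * jnp.pi * jnp.fft.fftfreq(ny, d=scale)
    ksq = ky[:, None] ** 2 + kx[None, :] ** 2
    kernel = jnp.exp(-0.5 * ksq * sigma**2)  # equals 1 at k = 0
    return jnp.fft.ifft2(jnp.fft.fft2(image) * kernel).real


# A single bright pixel is smoothed into a Gaussian blob with the same flux.
img = jnp.zeros((32, 32)).at[16, 16].set(1.0)
smooth = gaussian_lowpass(img, sigma=0.5, scale=0.2)
print(float(smooth.sum()), float(smooth.max()))
```

Pad the image first if the circular wrap-around matters. Within galsim itself, the same test would more naturally be written as a `galsim.Convolve` with a `galsim.Gaussian`.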

jecampagne commented 10 months ago

As for the 10-15 s it takes to get the JITted version: that is rather short :) When we connect to inference tools (numpyro & Co.), the compilation time is far larger.

beckermr commented 10 months ago

Good suggestions @jecampagne! I also have PR #65 open, which has a version of the k-space wrapping. I might make a new branch and try that PR together with this one there.

beckermr commented 10 months ago

OK. I figured out the issue here. You need the k-space wrapping, and you need to set maxk and stepk to the same values in both codes. The values jax-galsim computes for interpolated images differ a little from galsim's and cause rendering differences under convolutions. Once you do this, metacal agrees to better than 1e-7 in absolute tolerance.

beckermr commented 10 months ago

Alright @ismael-mendoza. This PR is a monster, but it is ready for review. If you have suggestions on how to break it into smaller chunks, I'd be happy to do that.

ismael-mendoza commented 10 months ago

Thanks! I will start taking a look later today and let you know if I have suggestions on how to split it.

beckermr commented 10 months ago

OK @ismael-mendoza. Back to you!

beckermr commented 10 months ago

OK @ismael-mendoza ready for one last look!