GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

vmap to build gaussian profile images (mainly drawing branch) #29

Closed jecampagne closed 1 year ago

jecampagne commented 1 year ago

Hi,

Using the drawing branch, I ran into some problems that I describe here. Some have a temporary workaround, but they may complicate the rest of the software development.

Let us try to make 2 galaxy profiles using a vmap on their flux & sigma parameters:

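import jax
import jax.numpy as jnp
import jax_galsim as galsim  # alias assumed from the jax_galsim paths in the traceback

def gen_gal(p:dict):
  # Build one Gaussian profile from a dict of parameters.
  gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
  return gal

print("tst vmap gen_gal")
params = {'flux': jnp.array([1.e5, 2.e5]), 'sigma': jnp.array([1.,2.])}
res = jax.vmap(gen_gal)(params)
print(res)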

the final error message is

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

/usr/local/lib/python3.8/dist-packages/jax_galsim/gaussian.py in __init__(self, half_light_radius, sigma, fwhm, flux, gsparams)
     63             super().__init__(sigma=sigma, flux=flux, gsparams=gsparams)
     64 
---> 65         self._sigsq = self.sigma**2
     66         self._inv_sigsq = 1.0 / self._sigsq
     67         self._norm = self.flux * self._inv_sigsq * Gaussian._inv_twopi

TypeError: unsupported operand type(s) for ** or pow(): 'object' and 'int'

A notebook reproducing this is available on Colab.

jecampagne commented 1 year ago

More debugging output:

# When I call to instantiate a single gaussian profile
JEC DBG: type self.sigma <class 'float'>
tst single: gen_gal galsim.Gaussian(sigma=1.0, flux=100000.0)

# Now on a vmap which aims to build 2 gal profiles
JEC DBG: type self.sigma <class 'jax.interpreters.batching.BatchTracer'>
JEC DBG: gsobject tree_flatten
JEC DBG: gsobject tree_unflatten
JEC DBG: type self.sigma <class 'object'>

jecampagne commented 1 year ago

Some more input.

On the "drawing" branch

def gen_gal(p:dict):
  gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
  return gal.flux + gal.sigma

print("tst vmap gen_gal")
params = {'flux': jnp.array([1.e5, 2.e5]), 'sigma': jnp.array([1.,2.])}
res = jax.vmap(gen_gal)(params)
print(res)

gives as expected [100001. 200002.].

The vmap at the top of this issue crashes because gen_gal was returning gal, which triggers

JEC DBG tree_flatten
JEC DBG tree_unflatten

whereas this is not the case when one returns a "number" computed from the galaxy profile.
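To illustrate why returning the object matters: as far as I understand from the JAX pytree documentation, transformations such as vmap may rebuild a custom pytree through tree_unflatten with placeholder leaves (plain object() instances), so any arithmetic done in __init__ can fail exactly as in the TypeError above. A minimal sketch, assuming a hypothetical ToyGaussian class (not the jax_galsim implementation) that only stores its parameters in __init__ and computes derived quantities lazily:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGaussian:
    """Hypothetical stand-in for a profile class, for illustration only."""

    def __init__(self, sigma, flux):
        # Only store the parameters. No arithmetic here, because JAX may
        # rebuild the object with placeholder leaves that are not arrays.
        self.sigma = sigma
        self.flux = flux

    @property
    def norm(self):
        # Derived quantities are computed lazily, once real values are present.
        return self.flux / (2.0 * jnp.pi * self.sigma**2)

    def tree_flatten(self):
        # Children are the traced leaves; there is no static auxiliary data.
        return (self.sigma, self.flux), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def gen_gal(p):
    return ToyGaussian(sigma=p["sigma"], flux=p["flux"])


params = {"flux": jnp.array([1.0e5, 2.0e5]), "sigma": jnp.array([1.0, 2.0])}
batched = jax.vmap(gen_gal)(params)  # one object holding batched leaves
print(batched.sigma, batched.flux, batched.norm)
```

The result is a single ToyGaussian whose sigma and flux are arrays of length 2, not a list of two objects.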

However, this is not the end of the story. If one wants to return the image array of the galaxy profiles using:

def draw_gal_img(p:dict):
  #Define the galaxy profile.
  gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
  img = gal.drawImage(scale=0.2)
  return img.array

params = {'flux': jnp.array([1.e5, 2.e5]), 'sigma': jnp.array([1.,2.])}
imgs = jax.vmap(draw_gal_img)(params)
print("result: ", imgs)

this causes problems in image.py (line 156) because of

   if ncol != int(ncol) or nrow != int(nrow): 
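This kind of check is problematic under vmap because ncol and nrow are traced values there, and int() needs a concrete value. A minimal sketch in plain JAX, where n_pixels and its size rule are hypothetical and not taken from jax_galsim:

```python
import jax
import jax.numpy as jnp


def n_pixels(sigma):
    # Hypothetical size rule: the image size depends on the profile parameter.
    n = jnp.ceil(10.0 * sigma)
    # int() requires a concrete value, which a traced value cannot provide.
    return int(n)


# Outside any transformation the value is concrete, so int() succeeds.
print(n_pixels(1.0))

# Under vmap, sigma is a tracer and the conversion raises a TypeError subclass.
try:
    jax.vmap(n_pixels)(jnp.array([1.0, 2.0]))
    raised = False
except TypeError as err:
    raised = True
    print("int() failed under vmap:", type(err).__name__)
```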

Then, if one comments out the check above, as well as the other lines that use the int function, one gets this traceback:

Traceback (most recent call last):
  File "test_vmap.py", line 33, in <module>
    imgs = jax.vmap(draw_gal_img)(params)
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/api.py", line 1682, in vmap_f
    out_flat = batching.batch(
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test_vmap.py", line 27, in draw_gal_img
    img = gal.drawImage(scale=0.2)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 479, in drawImage
    image = prof._setup_image(image, nx, ny, bounds, add_to_image, dtype, center)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 282, in _setup_image
    image = Image(N, N, dtype=dtype)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/image.py", line 160, in __init__
    self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/image.py", line 440, in _make_empty
    return jnp.zeros(shape=shape, dtype=dtype)
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2098, in zeros
    shape = canonicalize_shape(shape)
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 99, in canonicalize_shape
    return core.canonicalize_shape(shape, context)  # type: ignore
  File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/core.py", line 1899, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
  val = DeviceArray([ 62, 120], dtype=int64)
  batch_dim = 0, Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
  val = DeviceArray([ 62, 120], dtype=int64)
  batch_dim = 0).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The problem comes from the fact that the dimensions of the image, if not specified by the user, are a function of the galaxy profile parameters. But

def draw_gal_img(p:dict):
  #Define the galaxy profile.
  gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
  img = gal.drawImage(nx=50, ny=50, scale=0.2)
  return img.array

gives

result:  [[[2.3794119e-08 6.2143016e-08 1.5593487e-07 ... 1.5593487e-07
   6.2143016e-08 2.3794119e-08]
  [6.2143016e-08 1.6229869e-07 4.0725453e-07 ... 4.0725453e-07
   1.6229869e-07 6.2143016e-08]
  [1.5593487e-07 4.0725453e-07 1.0219200e-06 ... 1.0219200e-06
   4.0725453e-07 1.5593487e-07]
  ...
  [1.5593487e-07 4.0725453e-07 1.0219200e-06 ... 1.0219200e-06
   4.0725453e-07 1.5593487e-07]
  [6.2143016e-08 1.6229869e-07 4.0725453e-07 ... 4.0725453e-07
   1.6229869e-07 6.2143016e-08]
  [2.3794119e-08 6.2143016e-08 1.5593487e-07 ... 1.5593487e-07
   6.2143016e-08 2.3794119e-08]]

 [[7.8704125e-01 1.0005255e+00 1.2592615e+00 ... 1.2592615e+00
   1.0005255e+00 7.8704125e-01]
  [1.0005255e+00 1.2719172e+00 1.6008351e+00 ... 1.6008351e+00
   1.2719172e+00 1.0005255e+00]
  [1.2592615e+00 1.6008351e+00 2.0148110e+00 ... 2.0148110e+00
   1.6008351e+00 1.2592615e+00]
  ...
  [1.2592615e+00 1.6008351e+00 2.0148110e+00 ... 2.0148110e+00
   1.6008351e+00 1.2592615e+00]
  [1.0005255e+00 1.2719172e+00 1.6008351e+00 ... 1.6008351e+00
   1.2719172e+00 1.0005255e+00]
  [7.8704125e-01 1.0005255e+00 1.2592615e+00 ... 1.2592615e+00
   1.0005255e+00 7.8704125e-01]]]
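
The same constraint can be shown without jax_galsim: array shapes have to be static Python integers, while the values filling the array can be traced. A minimal sketch, assuming a hypothetical draw_gaussian helper that evaluates the Gaussian at pixel centers on a fixed grid (it is not meant to reproduce drawImage):

```python
import jax
import jax.numpy as jnp


def draw_gaussian(p, n=50, scale=0.2):
    # n is a plain Python int, so the output shape is known at trace time.
    coords = (jnp.arange(n) - (n - 1) / 2.0) * scale
    x, y = jnp.meshgrid(coords, coords)
    r2 = x**2 + y**2
    # Surface brightness normalization of a 2D Gaussian with total flux p["flux"].
    norm = p["flux"] / (2.0 * jnp.pi * p["sigma"] ** 2)
    # Multiply by the pixel area so that pixel values approximate fluxes.
    return norm * jnp.exp(-0.5 * r2 / p["sigma"] ** 2) * scale**2


params = {"flux": jnp.array([1.0e5, 2.0e5]), "sigma": jnp.array([1.0, 2.0])}
imgs = jax.vmap(draw_gaussian)(params)  # only the dict leaves are batched
print(imgs.shape)
```

Because n and scale are left at their Python defaults, vmap only maps over flux and sigma, and every image in the batch has the same 50 x 50 shape.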

ismael-mendoza commented 1 year ago

Closed by #21