GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

main (drawing) branch: vmap on catalog crash #31

Closed: jecampagne closed this issue 1 year ago

jecampagne commented 1 year ago

Here is a report on creating a catalog of Gaussian profiles and on a crash when applying a simple vmap to it.

  1. Catalog using the "struct of arrays" pattern:

    import jax.numpy as jnp
    import jax_galsim as galsim

    fluxes = jnp.array([100000., 200000.])
    sigmas = jnp.array([1., 2.])
    catalog = galsim.Gaussian(flux=fluxes, sigma=sigmas)
    catalog

    gives:

    galsim.Gaussian(sigma=DeviceArray([1., 2.], dtype=float64), flux=DeviceArray([100000., 200000.], dtype=float64), gsparams=galsim.GSParams(128,8192,0.005,5,0.001,1e-05,1e-05,1,0.0001,1e-06,1e-06,1e-08,1e-05))
  2. Now vmap a simple function:
    jax.vmap(lambda x: x.maxk)(catalog)

    leads to:

    /usr/local/lib/python3.8/dist-packages/jax_galsim/gsobject.py in tree_unflatten(cls, aux_data, children)
        556     def tree_unflatten(cls, aux_data, children):
        557         """Recreates an instance of the class from flatten representation"""
    --> 558         return cls(**(children[0]), **aux_data)

    /usr/local/lib/python3.8/dist-packages/jax_galsim/gaussian.py in __init__(self, half_light_radius, sigma, fwhm, flux, gsparams)
         64         super().__init__(sigma=sigma, flux=flux, gsparams=gsparams)
         65
    ---> 66         self._sigsq = self.sigma**2
         67         self._inv_sigsq = 1.0 / self._sigsq
         68         self._norm = self.flux * self._inv_sigsq * Gaussian._inv_twopi

    TypeError: unsupported operand type(s) for ** or pow(): 'object' and 'int'
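One plausible reading of this traceback (an interpretation, not confirmed in the thread): while matching in_axes against its input, jax.vmap can rebuild the argument pytree from placeholder leaves that are never used for arithmetic, and the 'object' in the error message suggests these are bare object() instances. Because tree_unflatten in gsobject.py rebuilds the object by calling cls(...), Gaussian.__init__ runs on those placeholders and fails at self.sigma**2. The sketch below reproduces that step directly with jax.tree_util.tree_unflatten instead of relying on vmap internals, and contrasts it with a plain dict of arrays, which vmaps without trouble. ToyGaussian and the peak quantity flux / (2 pi sigma^2) are hypothetical illustrations, not JAX-GalSim code.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# A plain dict of arrays is a pytree; vmap maps over axis 0 of every leaf.
params = {"flux": jnp.array([100000.0, 200000.0]), "sigma": jnp.array([1.0, 2.0])}
# Central surface brightness of each Gaussian: flux / (2 pi sigma^2).
peaks = jax.vmap(lambda p: p["flux"] / (2 * jnp.pi * p["sigma"] ** 2))(params)


@tree_util.register_pytree_node_class
class ToyGaussian:
    """Hypothetical stand-in whose __init__ does arithmetic on its leaves."""

    def __init__(self, sigma, flux):
        self.sigma = sigma
        self.flux = flux
        self._sigsq = self.sigma**2  # same pattern as line 66 of gaussian.py above

    def tree_flatten(self):
        return (self.sigma, self.flux), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Re-running __init__ here is what breaks on non-array leaves.
        return cls(*children)


catalog = ToyGaussian(sigma=jnp.array([1.0, 2.0]), flux=jnp.array([100000.0, 200000.0]))
treedef = tree_util.tree_structure(catalog)

# Rebuild the tree from placeholder leaves, the step vmap appears to perform.
try:
    tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)
    error = None
except TypeError as err:
    error = err

print(peaks)
print(error)  # prints: unsupported operand type(s) for ** or pow(): 'object' and 'int'
```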

ismael-mendoza commented 1 year ago

Closed by #21
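The thread does not say what #21 changed. As a hedged sketch of one common way to make such a class vmap-safe: keep tree_unflatten free of computation by creating the instance with object.__new__ and reattaching the leaves directly, and compute derived values such as the normalization lazily in a property. ToyGaussian and peak are hypothetical names; peak mirrors the _norm line in the traceback (flux * 1/sigma^2 * 1/(2 pi)).

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyGaussian:
    """Hypothetical Gaussian profile stored as a struct of arrays."""

    _inv_twopi = 0.5 / jnp.pi

    def __init__(self, sigma, flux):
        self.sigma = sigma
        self.flux = flux

    @property
    def peak(self):
        # Derived on access, so placeholder leaves are never used in arithmetic.
        return self.flux / self.sigma**2 * ToyGaussian._inv_twopi

    def tree_flatten(self):
        return (self.sigma, self.flux), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: reattach the leaves without computing anything.
        obj = object.__new__(cls)
        obj.sigma, obj.flux = children
        return obj


catalog = ToyGaussian(sigma=jnp.array([1.0, 2.0]), flux=jnp.array([100000.0, 200000.0]))
peaks = jax.vmap(lambda g: g.peak)(catalog)
print(peaks)
```

The trade-off is that derived values are recomputed on every access; returning them as extra leaves from tree_flatten would avoid that, at the cost of having vmap batch them as well.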