Closed: jecampagne closed this issue 1 year ago.
# When I instantiate a single Gaussian profile:
JEC DBG: type self.sigma <class 'float'>
tst single: gen_gal galsim.Gaussian(sigma=1.0, flux=100000.0)
# Now inside a vmap that builds 2 galaxy profiles:
JEC DBG: type self.sigma <class 'jax.interpreters.batching.BatchTracer'>
JEC DBG: gsobject tree_flatten
JEC DBG: gsobject tree_unflatten
JEC DBG: type self.sigma <class 'object'>
Some more input.
On the "drawing" branch:
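import jax
import jax.numpy as jnp
import jax_galsim as galsim

def gen_gal(p: dict):
    gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
    # Return a number computed from the profile, not the profile object itself
    return gal.flux + gal.sigma

print("tst vmap gen_gal")
params = {'flux': jnp.array([1.e5, 2.e5]), 'sigma': jnp.array([1., 2.])}
res = jax.vmap(gen_gal)(params)
print(res)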
gives, as expected, [100001. 200002.]
The vmap mentioned at the top of this issue crashes because gen_gal was returning gal itself, which triggers
JEC DBG tree_flatten
JEC DBG tree_unflatten
whereas this does not happen when one returns a number computed from the galaxy profile.
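The `<class 'object'>` printed in the debug output above fits a behaviour described in the JAX pytree documentation. When a function transformed by `jax.vmap` returns a custom pytree, JAX may internally call `tree_unflatten` with placeholder leaves such as bare `object()` instances, so any conversion or validation that runs during unflattening fails. The docs suggest keeping `tree_unflatten` free of such logic, for example by bypassing `__init__`. Below is a minimal sketch of that pattern using a hypothetical `ToyGaussian` class (not the jax_galsim `Gaussian` implementation):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGaussian:
    """Hypothetical profile class, not the jax_galsim implementation."""

    def __init__(self, sigma, flux=1.0):
        # Conversion and validation belong here, where inputs are real values.
        self.sigma = jnp.asarray(sigma)
        self.flux = jnp.asarray(flux)

    @property
    def half_light_radius(self):
        # Derived quantities are computed on demand instead of being stored,
        # so unflattening only has to restore the leaves.
        return jnp.sqrt(2.0 * jnp.log(2.0)) * self.sigma

    def tree_flatten(self):
        children = (self.sigma, self.flux)  # leaves that vmap will batch
        aux_data = None                     # no static metadata in this toy
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass placeholder objects as children,
        # and jnp.asarray(object()) would raise.
        obj = object.__new__(cls)
        obj.sigma, obj.flux = children
        return obj


def make_gal(p):
    # Returning the profile object itself, like the original gen_gal did.
    return ToyGaussian(sigma=p["sigma"], flux=p["flux"])


params = {"flux": jnp.array([1.0e5, 2.0e5]), "sigma": jnp.array([1.0, 2.0])}
gals = jax.vmap(make_gal)(params)  # one ToyGaussian with batched leaves
print(type(gals).__name__, gals.sigma, gals.flux)
print(gals.half_light_radius)
```

With this pattern the vmapped function can return the profile object, and the result is a single `ToyGaussian` whose leaves carry the batch dimension.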
However, this is not the end of the story. If one wants to return the image array of the galaxy profiles, using:
def draw_gal_img(p: dict):
    # Define the galaxy profile.
    gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
    img = gal.drawImage(scale=0.2)
    return img.array

params = {'flux': jnp.array([1.e5, 2.e5]), 'sigma': jnp.array([1., 2.])}
imgs = jax.vmap(draw_gal_img)(params)
print("result: ", imgs)
one runs into problems in image.py (line 156), due to
    if ncol != int(ncol) or nrow != int(nrow):
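The `int(ncol)` test fails because, inside `jax.vmap`, values derived from the batched parameters are tracers without a concrete Python value, so `int()` cannot be evaluated on them. A minimal pure-JAX reproduction follows, where `image_size` and its 6-sigma sizing rule are hypothetical and not the jax_galsim formula:

```python
import jax
import jax.numpy as jnp


def image_size(sigma, scale=0.2):
    # Hypothetical sizing rule: cover +/- 3 sigma with pixels of size `scale`.
    n = jnp.ceil(6.0 * sigma / scale)
    # int() needs a concrete value; under vmap, `n` is a tracer.
    return int(n)


# With a concrete input, int() works and returns a Python int.
concrete_size = image_size(1.0)
print(type(concrete_size).__name__)

# Under vmap the same call raises (JAX concretization errors subclass TypeError).
try:
    jax.vmap(image_size)(jnp.array([1.0, 2.0]))
    vmap_failed = False
except TypeError as err:
    vmap_failed = True
    print("vmap failed:", type(err).__name__)
```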
Then, if one comments out this line, as well as the other lines that call the int function, one gets this traceback:
Traceback (most recent call last):
File "test_vmap.py", line 33, in <module>
imgs = jax.vmap(draw_gal_img)(params)
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/api.py", line 1682, in vmap_f
out_flat = batching.batch(
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "test_vmap.py", line 27, in draw_gal_img
img = gal.drawImage(scale=0.2)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 479, in drawImage
image = prof._setup_image(image, nx, ny, bounds, add_to_image, dtype, center)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 282, in _setup_image
image = Image(N, N, dtype=dtype)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/image.py", line 160, in __init__
self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/image.py", line 440, in _make_empty
return jnp.zeros(shape=shape, dtype=dtype)
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2098, in zeros
shape = canonicalize_shape(shape)
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 99, in canonicalize_shape
return core.canonicalize_shape(shape, context) # type: ignore
File "/sps/lsst/users/campagne/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/core.py", line 1899, in canonicalize_shape
raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
val = DeviceArray([ 62, 120], dtype=int64)
batch_dim = 0, Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
val = DeviceArray([ 62, 120], dtype=int64)
batch_dim = 0).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
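The problem comes from the fact that, when not specified by the user, the image dimensions are computed from the galaxy profile parameters. Under vmap these are traced values, so jnp.zeros cannot build an array with that shape. But with an explicit image size,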
def draw_gal_img(p: dict):
    # Define the galaxy profile.
    gal = galsim.Gaussian(flux=p['flux'], sigma=p['sigma'])
    # Fix the image size so it does not depend on the traced parameters
    img = gal.drawImage(nx=50, ny=50, scale=0.2)
    return img.array
gives
result: [[[2.3794119e-08 6.2143016e-08 1.5593487e-07 ... 1.5593487e-07
6.2143016e-08 2.3794119e-08]
[6.2143016e-08 1.6229869e-07 4.0725453e-07 ... 4.0725453e-07
1.6229869e-07 6.2143016e-08]
[1.5593487e-07 4.0725453e-07 1.0219200e-06 ... 1.0219200e-06
4.0725453e-07 1.5593487e-07]
...
[1.5593487e-07 4.0725453e-07 1.0219200e-06 ... 1.0219200e-06
4.0725453e-07 1.5593487e-07]
[6.2143016e-08 1.6229869e-07 4.0725453e-07 ... 4.0725453e-07
1.6229869e-07 6.2143016e-08]
[2.3794119e-08 6.2143016e-08 1.5593487e-07 ... 1.5593487e-07
6.2143016e-08 2.3794119e-08]]
[[7.8704125e-01 1.0005255e+00 1.2592615e+00 ... 1.2592615e+00
1.0005255e+00 7.8704125e-01]
[1.0005255e+00 1.2719172e+00 1.6008351e+00 ... 1.6008351e+00
1.2719172e+00 1.0005255e+00]
[1.2592615e+00 1.6008351e+00 2.0148110e+00 ... 2.0148110e+00
1.6008351e+00 1.2592615e+00]
...
[1.2592615e+00 1.6008351e+00 2.0148110e+00 ... 2.0148110e+00
1.6008351e+00 1.2592615e+00]
[1.0005255e+00 1.2719172e+00 1.6008351e+00 ... 1.6008351e+00
1.2719172e+00 1.0005255e+00]
[7.8704125e-01 1.0005255e+00 1.2592615e+00 ... 1.2592615e+00
1.0005255e+00 7.8704125e-01]]]
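The same constraint applies to array shapes: `jnp.zeros` needs concrete integers, so an image whose size is computed from a batched `sigma` cannot be created inside `jax.vmap`, while a size fixed by Python constants can. Here is a pure-JAX sketch of both cases. `draw_gaussian` and `draw_gaussian_auto_size` are hypothetical helpers, and `draw_gaussian` point-samples the profile, skipping the pixel convolution that GalSim's default `drawImage` method applies, so its values will not match the output above exactly:

```python
import jax
import jax.numpy as jnp


def draw_gaussian(flux, sigma, n=50, scale=0.2):
    """Hypothetical helper: point-sample a circular Gaussian on an n x n grid.

    n and scale stay Python numbers, so the output shape is known at trace time.
    """
    # Pixel-centre coordinates, centred on the middle of the grid.
    coords = (jnp.arange(n) - (n - 1) / 2.0) * scale
    x, y = jnp.meshgrid(coords, coords)
    profile = jnp.exp(-0.5 * (x**2 + y**2) / sigma**2) / (2.0 * jnp.pi * sigma**2)
    # Multiply by the pixel area so pixel values approximate flux per pixel.
    return flux * profile * scale**2


def draw_gaussian_auto_size(flux, sigma, scale=0.2):
    # Hypothetical sizing rule: the grid size depends on sigma, which is traced.
    n = jnp.ceil(6.0 * sigma / scale).astype(jnp.int32)
    return jnp.zeros((n, n)) + flux


flux = jnp.array([1.0e5, 2.0e5])
sigma = jnp.array([1.0, 2.0])

imgs = jax.vmap(draw_gaussian)(flux, sigma)
print(imgs.shape)  # (2, 50, 50)

try:
    jax.vmap(draw_gaussian_auto_size)(flux, sigma)
    auto_size_failed = False
except TypeError as err:  # "Shapes must be 1D sequences of concrete values..."
    auto_size_failed = True
    print("auto-size vmap failed:", type(err).__name__)
```

Fixing the grid size with Python constants keeps the output shape independent of the traced parameters, which is what the `nx=50, ny=50` workaround relies on.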
Closed by #21
Hi,
Using the drawing branch, I am running into some problems, which I describe here. Some of them have a temporary workaround, but they may make the rest of the software development difficult.
Let us try to build 2 galaxy profiles with a vmap over their flux and sigma parameters:
The final error message is:
See here for a notebook on Colab.