GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

Adds jittable Position object #11

Closed: EiffL closed this 1 year ago.

EiffL commented 2 years ago

This PR implements the Position object from GalSim: https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/galsim/position.py#L21 and adds a couple of related features to GSObject.
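Here "jittable" means that a Position can be passed into and returned from functions compiled with `jax.jit`. One common way to get this, sketched below, is to register the class as a JAX pytree with `jax.tree_util.register_pytree_node_class` so that the coordinates become pytree leaves. `SimplePosition` is a hypothetical name, and this is an illustration under that assumption, not the PR's actual implementation.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SimplePosition:
    """Hypothetical 2D position whose coordinates are pytree leaves."""

    def __init__(self, x, y):
        # Store inputs as-is: no dtype conversion happens here.
        self.x = x
        self.y = y

    def __add__(self, other):
        return SimplePosition(self.x + other.x, self.y + other.y)

    def tree_flatten(self):
        # Children are the dynamic leaves; no static auxiliary data is needed.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def shift(pos, offset):
    # Positions go in and come out of the compiled function as pytrees.
    return pos + offset


p = shift(SimplePosition(1.0, 2.0), SimplePosition(0.5, -0.5))
```

Because only `x` and `y` are leaves, `jax.jit` traces them as arrays and rebuilds a `SimplePosition` on output.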

The function parse_pos_args in utilities.py is almost a straight copy of the GalSim version, except that it instantiates the correct jax-galsim Position class.
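As a rough illustration of what this kind of argument parsing does, the sketch below accepts a position object, an `(x, y)` tuple, two positional values, or `x=`/`y=` keywords, and always returns an instance of the requested class. `parse_pos_args_sketch` and the stand-in `PositionD` are hypothetical; GalSim's real `parse_pos_args` has a different signature and handles more cases.

```python
class PositionD:
    """Hypothetical stand-in for a jax-galsim position class."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def parse_pos_args_sketch(args, kwargs, cls=PositionD):
    """Hypothetical simplified parser: accepts pos, (x, y), x, y or x=, y=."""
    if len(args) == 0:
        # Keyword form, e.g. f(x=1.0, y=2.0).
        if set(kwargs) != {"x", "y"}:
            raise TypeError("Expected exactly the keyword arguments x and y")
        return cls(kwargs["x"], kwargs["y"])
    if kwargs:
        raise TypeError("Cannot mix positional and keyword position arguments")
    if len(args) == 1:
        (pos,) = args
        if isinstance(pos, cls):
            return pos  # already the requested class
        if hasattr(pos, "x") and hasattr(pos, "y"):
            return cls(pos.x, pos.y)  # e.g. a different Position class
        if isinstance(pos, tuple) and len(pos) == 2:
            return cls(*pos)
        raise TypeError(f"Cannot parse {pos!r} as a position")
    if len(args) == 2:
        return cls(*args)  # f(x, y)
    raise TypeError("Too many positional arguments for a position")
```

Returning `cls(...)` in every branch is what lets the same parser produce whichever Position class the caller asks for.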

I don't think there is a dedicated test file for positions in the GalSim test suite, but positions are used all over the place, so they get tested implicitly.

b-remy commented 1 year ago

Hi @ismael-mendoza, thank you for your review! In my opinion, there is still an issue with these lines:

https://github.com/GalSim-developers/JAX-GalSim/blob/f38ba4cb8b1f753fd66343a9ecb559d24d62f3e3/jax_galsim/position.py#L49-L51

This jax.numpy array conversion will break vmapping because self.x and self.y are converted to jnp.ndarray and so are no longer JAX tracers... This is what #33 is supposed to fix, by simply removing these lines. The purpose of these lines was to be able to change the dtype of the array in PositionD and PositionI, e.g.

https://github.com/GalSim-developers/JAX-GalSim/blob/f38ba4cb8b1f753fd66343a9ecb559d24d62f3e3/jax_galsim/position.py#L146-L150

This is something we cannot do, because if we create a Position object with float arguments, x.astype("int") will break.
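A small illustration of this failure mode, independent of the PR's code: a plain Python float has no `astype` method, and converting it to a JAX array first avoids the error but silently truncates the value instead of flagging it.

```python
import jax.numpy as jnp

x = 1.5  # a plain Python float, as a user might pass to PositionI(1.5, 2.0)

# Python floats have no .astype method, so x.astype("int") raises AttributeError.
try:
    x.astype("int")
    failed = False
except AttributeError:
    failed = True

# Converting to a JAX array first makes .astype available, but the cast
# truncates 1.5 to 1 rather than rejecting the non-integer input.
xi = jnp.asarray(x).astype(jnp.int32)
```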

One solution could be to remove these type conversions and add warnings or errors in PositionD and PositionI if the arguments are not the expected type.
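A minimal sketch of such a check, assuming a hypothetical helper `validate_integer_coords` that PositionI could call instead of casting. It only inspects dtypes via `jnp.result_type`, which is known statically during tracing, so it raises for float inputs without converting any values and still works inside `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def validate_integer_coords(x, y):
    """Hypothetical check for PositionI: raise instead of casting.

    jnp.result_type accepts Python scalars, arrays and tracers and only
    looks at their dtypes, so no values are converted here.
    """
    for name, value in (("x", x), ("y", y)):
        dtype = jnp.result_type(value)
        if not jnp.issubdtype(dtype, jnp.integer):
            raise TypeError(f"PositionI expects an integer {name}, got dtype {dtype}")


@jax.vmap
def pixel_sum(x, y):
    # The dtype check runs at trace time on the batched tracers.
    validate_integer_coords(x, y)
    return x + y


out = pixel_sum(jnp.arange(3), jnp.arange(3))

try:
    validate_integer_coords(1.5, 2)
    rejected_float = False
except TypeError:
    rejected_float = True
```

Swapping `raise TypeError` for `warnings.warn` would give the warning variant of the same idea.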

This being said, this PR passes the tests and works for non-vmapped inputs. I propose we merge it and iterate on the vmapping fix in #33. What do you think, @ismael-mendoza, @EiffL?
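For the vmapping follow-up, one pattern described in the JAX documentation on custom pytrees is to let `tree_unflatten` bypass `__init__`. Any conversion or validation in `__init__` then runs only on user-facing construction, never on the leaves JAX passes when it rebuilds the object. A sketch with a hypothetical `PositionSketch` class:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PositionSketch:
    """Hypothetical Position that converts user inputs in __init__."""

    def __init__(self, x, y):
        # User-facing conversion from Python scalars or lists.
        self.x = jnp.asarray(x)
        self.y = jnp.asarray(y)

    def tree_flatten(self):
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so JAX can store whatever leaf
        # values it passes (e.g. batched tracers) without conversion.
        obj = object.__new__(cls)
        obj.x, obj.y = children
        return obj


def radius_squared(pos):
    return pos.x**2 + pos.y**2


xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([0.0, 1.0, 2.0])

# vmap maps over the leading axis of every leaf of the PositionSketch pytree.
r2 = jax.vmap(radius_squared)(PositionSketch(xs, ys))

# Returning a PositionSketch from a vmapped function also works.
shifted = jax.vmap(lambda p: PositionSketch(p.x + 1.0, p.y))(PositionSketch(xs, ys))
```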

ismael-mendoza commented 1 year ago

Hi @b-remy - thanks so much for your detailed response explaining the issue. I agree that this is still a problem in this PR exactly for the reason you pointed out.

> One solution could be to remove these type conversions and add warnings or errors in PositionD and PositionI if the arguments are not the expected type.

I like this solution, and I also agree that perhaps we should merge this first and resolve this later in #33. Perhaps we should also add specific tests for vmapping?
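A vmapping test could compare a batched call against a plain Python loop. The sketch below uses a hypothetical stand-in `PositionD` pytree and a hypothetical test name; a real test would import the jax-galsim class instead.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PositionD:
    """Hypothetical stand-in for the class under test."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def test_position_vmap():
    xs = jnp.linspace(0.0, 1.0, 5)
    ys = jnp.linspace(-1.0, 0.0, 5)

    def scale(x, y):
        # Build a Position inside the mapped function, as user code would.
        pos = PositionD(x, y)
        return PositionD(2.0 * pos.x, 2.0 * pos.y)

    out = jax.vmap(scale)(xs, ys)

    # The batched result should match applying scale element by element.
    expected_x = np.array([scale(x, y).x for x, y in zip(xs, ys)])
    np.testing.assert_allclose(out.x, expected_x)
    np.testing.assert_allclose(out.y, 2.0 * np.asarray(ys))
```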

b-remy commented 1 year ago

Yes, 100%. If you look at #33, you will see vmapping tests for all classes up to image.py.

ismael-mendoza commented 1 year ago

Ah nice, I will take a closer look at #33 later then. I wonder if we should go ahead and merge (I know Francois is quite busy with DESC-related things currently).

b-remy commented 1 year ago

I agree we should merge, as it is all correct for float inputs. Let's keep cleaning up the PRs and fix vmapping separately.

ismael-mendoza commented 1 year ago

nice, I will go ahead then...