Closed: EiffL closed this 1 year ago.
Hi @ismael-mendoza, thank you for your review! In my opinion, there is still an issue with these lines:
This jax numpy array conversion will break vmapping because `self.x` and `self.y` are converted to `jnp.ndarray`, so they are no longer JAX tracers. This is what #33 is supposed to fix, by simply removing these lines. The purpose of these lines was to be able to change the dtype of the ndarray in `PositionD` and `PositionI`, e.g.

This is something we cannot do, because if we create a `Position` object with float arguments, `x.astype("int")` will break.
One solution could be to remove these type conversions and add warnings or errors in `PositionD` and `PositionI` if the arguments are not the expected type.
That being said, this PR passes the tests and works for non-vmapped inputs. What I propose is to merge it and iterate on how to support vmapping in #33. What do you think about that, @ismael-mendoza, @EiffL?
Hi @b-remy - thanks so much for your detailed response explaining the issue. I agree that this is still a problem in this PR exactly for the reason you pointed out.
One solution could be to remove these type conversions and add warnings or errors in PositionD and PositionI if the arguments are not the expected type.
I like this solution, and I also agree that perhaps we should merge this first and resolve this later in #33. Perhaps we should also add specific tests for vmapping?
Yes, 100%! If you look at #33, you will see vmapping tests for all classes up to `image.py`.
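For illustration only, a vmapping test could compare a batched call against a plain Python loop over the same inputs. This sketch is not the test from #33; it assumes a hypothetical minimal `PositionD` registered as a JAX pytree (via `jax.tree_util.register_pytree_node_class`) so that it can be returned from a vmapped function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PositionD:
    """Hypothetical minimal stand-in for a jax-galsim position."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, other):
        return PositionD(self.x + other.x, self.y + other.y)

    def tree_flatten(self):
        # x and y are the array leaves; there is no static auxiliary data.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def test_position_vmap():
    xs = jnp.arange(4.0)
    ys = 2.0 * xs
    # Batched: construct and add positions for every element at once.
    batched = jax.vmap(lambda x, y: PositionD(x, y) + PositionD(1.0, 1.0))(xs, ys)
    # Reference: the same computation, one scalar element at a time.
    for i in range(xs.shape[0]):
        ref = PositionD(xs[i], ys[i]) + PositionD(1.0, 1.0)
        assert jnp.allclose(batched.x[i], ref.x)
        assert jnp.allclose(batched.y[i], ref.y)
```

Comparing against an unbatched reference catches both outright failures under `jax.vmap` (such as the tracer issue above) and silent shape or broadcasting mistakes.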
Ah nice, I will take a closer look at #33 later, then. I wonder if we should go ahead and merge (I know Francois is quite busy with DESC-related things at the moment).
I agree to merge as it is all correct for float inputs. Let's continue to clean the PRs and fix vmapping separately.
nice, I will go ahead then...
This PR implements the `Position` object from GalSim (https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/galsim/position.py#L21) and adds a couple of related functionalities to `GSObject`.

The function `parse_pos_args` in `utilities.py` is almost a straight copy from GalSim, but it instantiates the correct jax-galsim `Position` class.

I don't think there is a dedicated test file for positions in the GalSim test suite, but positions are used all over the place, so they get tested implicitly.
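To make the role of `parse_pos_args` concrete, here is a simplified sketch of GalSim-style position parsing. It is an assumption-laden illustration, not the code in `utilities.py`: the dataclass `PositionD`/`PositionI` stand-ins are hypothetical, GalSim's `others` handling is omitted, and only the `integer` flag is used to choose which class to instantiate.

```python
from dataclasses import dataclass


@dataclass
class PositionD:  # hypothetical stand-in for jax_galsim.PositionD
    x: float
    y: float


@dataclass
class PositionI:  # hypothetical stand-in for jax_galsim.PositionI
    x: int
    y: int


def parse_pos_args(args, kwargs, name1, name2, integer=False):
    """Accept f(pos), f((x, y)), f(x, y) or f(name1=x, name2=y).

    Returns a PositionI if `integer` is True, otherwise a PositionD.
    Note: consumed keyword arguments are popped from `kwargs`.
    """
    Position = PositionI if integer else PositionD
    if len(args) == 0:
        # Keyword form, e.g. f(x=1.0, y=2.0)
        try:
            x = kwargs.pop(name1)
            y = kwargs.pop(name2)
        except KeyError as err:
            raise TypeError(f"Expected {name1} and {name2} keyword arguments") from err
    elif len(args) == 1:
        obj = args[0]
        if isinstance(obj, (PositionD, PositionI)):
            x, y = obj.x, obj.y
        else:
            # Otherwise expect a length-2 sequence such as a tuple (x, y)
            try:
                x, y = obj
            except (TypeError, ValueError) as err:
                raise TypeError("Single argument must be a Position or an (x, y) pair") from err
    elif len(args) == 2:
        x, y = args
    else:
        raise TypeError("Too many positional arguments")
    if kwargs:
        raise TypeError(f"Unexpected keyword arguments: {sorted(kwargs)}")
    return Position(x, y)
```

The single point where the class is chosen is what lets a ported helper return jax-galsim positions instead of GalSim ones.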