pmelchior / scarlet2

Scarlet, all new and shiny
MIT License

Consistent treatment of coordinates #51

Closed pmelchior closed 2 months ago

pmelchior commented 4 months ago

We make implicit assumptions about the coordinate frame, in particular for source centers. When no WCS is set in the observed frame, they must be in pixel coordinates, but if a WCS is set, they may be in pixel coordinates or in sky coordinates. We assume it's the latter, but that may not be true and can lead to confusion, e.g. this:

Originally posted by @charlotteaward in https://github.com/pmelchior/scarlet2/issues/47#issuecomment-2103497628

As we already depend on astropy for the WCS transformation, it would make sense to allow positions to be given as SkyCoord (in which case it will be transformed according to a WCS) or a jnp.array (in which case we assume it to be in pixel coordinates).

charlotteaward commented 4 months ago

Quick note that we also need to adjust plot.py, which assumes pixel coordinates for sources when adding labels!

pmelchior commented 4 months ago

After more conversation about this, we should also make sure that all operations involving the user can be made in either all WCS or all pixel coordinates. This is particularly important for multi-observation cases.

I suggest that we define all model coordinates in the frame of the sky. That means that center coordinates would be in RA/Dec, with step sizes in, say, arcsec; same for sizes of radial profiles. We can make use of the astropy.coordinates.SkyCoord and astropy.units framework to automate conversions.

This will provide one important advantage: when a user defines the model frame, they can now simply say it's an abstract piece of sky, with a center and a size. From that perspective, it's not an image that is tacked onto the sky, which has routinely led to difficulties with new users. We, internally, compute images, but they represent that sky, so it makes sense to assume the coordinates of the sky, not of an image of the sky.

For this to work, we need to allow for two modes:

pmelchior commented 3 months ago

To further clarify our approach, I think we should allow the definition of model properties and parameters in sky coordinates, e.g.:

import astropy.units as u
from astropy.coordinates import SkyCoord

# ra, dec: angular Quantities (or pass unit=... to SkyCoord)
center = SkyCoord(ra, dec)
size = 1*u.arcsec
morph = GaussianMorphology(center, size)

# using the parameter framework from PR #56
parameters += Parameter(morph.center, stepsize=0.01*u.arcsec)

For fitting/sampling we now have two choices:

  1. Convert sky to pixels, optimize in pixels, and then convert back. This could look like Scene._constraint_replace, run only at the beginning or the end of the optimization. This is straightforward, and we're doing this already in parts in the initialization code, but we have to catch instances like Parameter(morph.center, prior=numpyro.distributions.Normal(morph.center, 1*u.arcsec)), i.e. when the constraints or priors are specified in sky.
  2. Optimize in sky coordinates. Doing so would keep parameter values and constraint/priors aligned, but we have to make coordinates and their transformations differentiable.

Option 2 is much more work for very little gain. It might even be worse (try to optimize on both sides of RA=360). So, the trick will be to catch all instances of sky coordinates, convert them consistently to get into pixels for the optimizer/sampler, and then go back to present sky coordinates to the user.
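To illustrate the RA=360 problem with plain numbers: two positions a fraction of an arcsec apart can look almost 360 degrees apart if RA is treated as an ordinary real-valued parameter. The helper `ra_offset_deg` is a hypothetical name, and the sketch ignores the cos(Dec) factor:

```python
def ra_offset_deg(ra1, ra2):
    """Signed difference ra2 - ra1 in degrees, wrapped into [-180, 180)."""
    return (ra2 - ra1 + 180.0) % 360.0 - 180.0


# two positions 0.0002 deg apart, straddling the RA = 0/360 boundary
ra_a, ra_b = 359.9999, 0.0001

naive = ra_b - ra_a                   # close to -360 deg: looks like a huge jump
wrapped = ra_offset_deg(ra_a, ra_b)   # small positive offset: the actual separation
```

An optimizer or a Gaussian prior working directly on RA would see the naive difference unless every update and every distance is wrapped like this, which is part of the extra work Option 2 would require.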

There's a wrinkle. While NumPyro does not check arguments when declaring distributions, so you can create numpyro.distributions.Normal(loc=SkyCoord(ra=10, dec=0, unit='deg'), scale=1*u.arcsec), it can't evaluate these distributions because it expects simple scalars or jnp.arrays. This may not be a problem because we convert these into pixel-based values beforehand, right? Yes, but we have to find all occurrences. So, I suggest that when doing the coordinate transformations of Parameter fields, we should check all attributes for astropy units or coordinates, e.g.

for fieldname in ['node', 'constraint', 'prior', 'stepsize']:
  field = getattr(parameter, fieldname)
  for name in dir(field):
    attrib = getattr(field, name)
    if isinstance(attrib, (u.Quantity, astropy.coordinates.SkyCoord)):
      # write the converted value back onto the field; rebinding attrib alone would not change it
      setattr(field, name, frame.wcs_conversion_function(attrib))

This works because numpyro distributions do not protect their properties (like loc and scale for a Normal), so new values can be assigned to them, which is what the last line above is meant to do.

pmelchior commented 3 months ago

And some more clarification on how we can best do this: As we will always compute models and their gradients in the pixel space, it's easiest to convert coordinates and distances at init time.

That means, a user can write:

center = SkyCoord(ra, dec)
PointSource(center, spectrum)

as long as we do the coordinate conversion to pixel of the model frame in PointSource.__init__ (or the method that is being called in __init__ to set the centers).

This means that we create an interface for the user to specify sky coordinates, and we'll immediately translate to pixels. Doing that just takes a menial task off the user, and ensures that the transformation is correctly applied (it's very easy to do that wrong, we've all done it...)

I think this provides a clear model definition: scarlet models describe a cube on the sky, whose location is specified by the model frame WCS. We simply make it easier for the user to define the model.

That leaves the Parameters. We can exploit that there is a custom __iadd__ for adding a Parameter instance to the Parameters list. At this point, we can modify the Parameter instance so that all fields holding sky coordinates and astropy units are converted to model pixels (with the code sketched above).
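The hook could look roughly like the sketch below. The `Parameter` and `Parameters` classes here are simplified stand-ins, not scarlet2's actual classes, and the `convert` callable is a hypothetical placeholder for the sky-to-pixel conversion bound to the model frame:

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Parameter:
    """Minimal stand-in, not scarlet2's actual Parameter class."""
    node: Any
    stepsize: Any = None
    prior: Optional[Any] = None


class Parameters:
    """List-like container that converts each Parameter as it is added."""

    def __init__(self, convert: Callable[[Any], Any]):
        self._convert = convert  # e.g. a sky-to-pixel conversion for the model frame
        self._items = []

    def __iadd__(self, parameter: Parameter):
        for fieldname in ("node", "stepsize", "prior"):
            value = getattr(parameter, fieldname)
            if value is not None:
                setattr(parameter, fieldname, self._convert(value))
        self._items.append(parameter)
        return self  # __iadd__ must return the object that stays bound to the name

    def __len__(self):
        return len(self._items)

    def __getitem__(self, i):
        return self._items[i]


# toy conversion: treat numbers as arcsec and divide by a 0.2 arcsec pixel scale
parameters = Parameters(convert=lambda arcsec: arcsec / 0.2)
parameters += Parameter(node=1.0, stepsize=0.01)
```

Because the conversion happens inside `+=`, the user-facing call stays unchanged and everything stored in the list is already in pixel units by the time the optimizer or sampler sees it.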