LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License
42 stars 6 forks

Cannot compute gradients for scalar parameters #271

Open h-greer opened 5 days ago

h-greer commented 5 days ago

It appears that Zodiax doesn't compute gradients for any source or aperture parameters which are scalar values (i.e. floats). This causes a lot of parameter inference stuff (both gradient optimisation and HMC) to break in non-obvious ways.

Here's a minimal example that reproduces the bug. It sets up a simple telescope with a circular aperture and a point source, and then tries to compute the gradient with respect to the flux of the point source and print it. Running this on the latest dLux prints None, which indicates that no gradient was computed, rather than the expected output of [0.0]. Changing the path variable to "source.position" (a JAX array of length 2) gives a successful gradient computation, but "aperture.radius" (again a float) does not.

import jax.numpy as np
import jax.scipy as jsp

import zodiax as zdx

import dLux as dl

npix = 64
diam = 2

aperture = dl.layers.CircularAperture(radius=diam/2, transformation=dl.CoordTransform(), occulting=True, normalise=True)

optics = dl.AngularOpticalSystem(npix, diam, [("aperture",aperture)], 64, 50e-3, 1)

wavelengths = np.asarray([1.2e-6])
weights = np.asarray([1])

source = dl.PointSource(wavelengths=wavelengths)

telescope = dl.Telescope(optics, ("binary",source))

data = telescope.model()

path = "source.flux"

@zdx.filter_value_and_grad(path)
def loss_fn(model, data):
    out = model.model()
    return -np.sum(jsp.stats.poisson.logpmf(data, out))

loss, grads = loss_fn(telescope, data)

print(grads.get(path))

I believe the issue arises from changes to the __init__ method for the sources and apertures introduced in #246. Taking the point source constructor as an example (docstring removed for clarity), the scalar flux input is typecast to a float. I think this is the root cause of the issue, since JAX cannot trace through the float call, and thus can't compute the gradient. Changing the self.flux line to match the self.position line rectifies the issue.

def __init__(
    self: Source,
    wavelengths: Array = None,
    position: Array = np.zeros(2),
    flux: float = 1.0,
    weights: Array = None,
    spectrum: Spectrum = None,
):
    # Position and Flux
    self.position = np.asarray(position, dtype=float)  # stored as a JAX array
    self.flux = float(flux)  # stored as a plain Python float, not a JAX array

    if self.position.shape != (2,):
        raise ValueError("position must be a 1d array of shape (2,).")

    super().__init__(
        wavelengths=wavelengths, weights=weights, spectrum=spectrum
    )

If you think this fix is sound, I'm happy to make a PR fixing all of the occurrences of the typecast.

h-greer commented 5 days ago

Probably worth noting that this is a superset of the issue reported in #264 (since it affects all the apertures and sources). I don't think a Zodiax-side fix as mentioned there will work out, because the float() call strips out the JAX information. This can be tested with a very simple JAX-only example:

from jax import grad
import jax.numpy as jnp

@grad
def f(x):
    return float(x)

@grad
def f2(x):
    return jnp.asarray(x, float)

print(f(1.0)) # fails with ConcretizationTypeError

print(f2(1.0)) # outputs 1.0 as expected

LouisDesdoigts commented 5 days ago

Yeah, this is a known issue, as you point out. I would prefer to keep the dLux side as it is, since having a float leaf allows us to easily check its actual value when printing (rather than just the shape and dtype).

The true fix is to inject some code I've been using locally:

import equinox as eqx
import jax
import jax.numpy as np
import jax.tree_util as jtu

def set_array(pytree, parameters):
    dtype = np.float64 if jax.config.jax_enable_x64 else np.float32
    floats, other = eqx.partition(pytree, eqx.is_inexact_array_like)
    floats = jtu.tree_map(lambda x: np.array(x, dtype=dtype), floats)
    return eqx.combine(floats, other)

into the inner_wrapper function of both Zodiax filter functions in https://github.com/LouisDesdoigts/zodiax/blob/main/zodiax/eqx.py, where it is then called on any object we want to take gradients with respect to.

There's a series of small fixes and improvements like these that have been hanging around locally while I've been working on other projects. You can either just use this fix locally or submit a PR into Zodiax if you'd prefer.

h-greer commented 5 days ago

I would prefer to keep the dLux side as it is, since having a float leaf allows us to easily check its actual value when printing (rather than just the shape and dtype).

That's a pretty compelling reason, and I would agree that fixing it in Zodiax is a much cleaner solution (I just didn't realise that was possible). Once I get it working locally, I'll make a PR.

h-greer commented 5 days ago

@LouisDesdoigts how exactly have you been using that set_array function? When I try to use it, I get a TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'

LouisDesdoigts commented 5 days ago

So this function is designed to be used with a set of input parameter paths as defined in Zodiax. Once we have defined which leaves we want to take gradients with respect to, i.e. params = ['diameter', 'flux', 'coefficients'], we can then call the set_array function on our model, and it should return the same object but with those leaves enforced to be arrays of the correct dtype, i.e.:

optics = ... 

params = ['diameter', 'flux', 'coefficients']

optics = set_array(optics, params) # This object should always return gradients for all leaves defined by params

However, judging by that error message, I suspect that you are pointing to a leaf that is set to None. I should also note that this method is designed to use the same set of params that is passed into the Zodiax filter functions (filter_grad, filter_value_and_grad).

Could you provide a minimum working example (MWE)? It's hard to say much more with a single error that I can't reproduce.

h-greer commented 5 days ago

It looks like I was using the function incorrectly, then. I'd tried putting it inside the Zodiax filter functions directly, but that broke the optimiser with that TypeError when it tried to combine an array-ified gradient with an unmodified system.

LouisDesdoigts commented 5 days ago

Ultimately it should live within the Zodiax filter function, but applying it to the model before passing the model to the filter function was just my hacky way around modifying Zodiax.

Looking back over the function I sent, it actually isn't doing exactly what I said. It has no dependence on the parameters input (an earlier version did) and simply converts every leaf it can to an array of the correct dtype, to avoid re-compilation issues arising from leaves with different precisions. That said, this method should actually be more robust: it can't change any of the behavior of the object it returns, as an array type is a strict super-set of a float type.

If your error now arises from an interaction with the optimiser object (I assume an optax object?), it's possible optax strictly enforces exact dtypes. The quick test for this would be to call the set_array function before passing the model to any optimisers.

h-greer commented 5 days ago

I've found the cause of the optimiser error: there was an is_array instead of is_inexact_array_like inside zodiax.get_optimiser that was obliterating the floats before they even made it into optax. Changing that and using set_array inside the filter functions seems to work without errors. I'll double-check that I haven't made any mistakes and make a PR in Zodiax tomorrow.

LouisDesdoigts commented 4 days ago

Ah yes that would do it, good find!

I've actually overhauled that function and various other things in zodiax in my local work (it's in need of some tender love and care tbh), so that function is somewhat out of date, but all improvements are welcome!