LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

PointSource enforces flux as float #264

Open · maxecharles opened this issue 7 months ago

maxecharles commented 7 months ago

Related to this Zodiax issue, where the gradient with respect to a float leaf is returned as None.

dl.PointSource casts flux to a Python float in its constructor:

```python
# Excerpt from the dLux source (np is jax.numpy)
class PointSource(Source):
    ...
    def __init__(...):
        # Position and Flux
        self.position = np.asarray(position, dtype=float)
        self.flux = float(flux)  # plain Python float, not a JAX array

        if self.position.shape != (2,):
            raise ValueError("position must be a 1d array of shape (2,).")
```

I worked around this by converting flux to a JAX array, in the same way as position, and adding a check that it is a scalar:

```python
# My alteration of the dLux source (np is jax.numpy)
class PointSource(Source):
    ...
    def __init__(...):
        # Position and Flux
        self.position = np.asarray(position, dtype=float)
        self.flux = np.asarray(flux, dtype=float)  # 0-d JAX array

        if self.position.shape != (2,):
            raise ValueError("position must be a 1d array of shape (2,).")
        if self.flux.shape != ():
            raise ValueError("flux must be a scalar value.")
```

I can push this change if needed, or would it be better to fix it in Zodiax instead?

LouisDesdoigts commented 7 months ago

This should be fixed on the Zodiax side, probably either by pre-processing the leaves into arrays or by using a different Equinox filter function. I'll need to check how that interacts with jax.grad calls before making a final call.