GalacticDynamics / unxt

Unitful Quantities in JAX
https://unxt.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Grad through a quantity function #241

Open adrn opened 20 hours ago

adrn commented 20 hours ago

I tried this, thinking it might work, and was surprised that it throws an error. Is this a bug, or something we can't support at the moment?

import jax
import astropy.units as u
from unxt import Quantity
from unxt import experimental

length = u.get_physical_type("length")
time = u.get_physical_type("time")
velocity = u.get_physical_type("velocity")

@jax.jit
def test_ad1(x: Quantity[length], t: Quantity[time]) -> Quantity[velocity]:
    return x / t

experimental.grad(test_ad1, units="si", argnums=1)(
    Quantity(15.0, u.m), Quantity(1.0, u.s)
)
...
UnitConversionError: 'm' (length) and 's' (time) are not convertible

nstarman commented 13 hours ago

We should figure out how to support this!

The following does already work.

experimental.grad(test_ad1, units=(u.m, u.s), argnums=1)(
    Quantity(15.0, u.m), Quantity(1.0, u.s)
)

But passing just units="si" would be simpler.
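
As a sanity check on the value the working call should produce, here is a minimal plain-JAX sketch on bare floats. It assumes the inputs are already expressed in metres and seconds, and the name `velocity_value` is a hypothetical stand-in for `test_ad1`, not part of unxt.

```python
import jax


# Plain-float stand-in for test_ad1 (hypothetical name):
# x is assumed to be in metres, t in seconds.
def velocity_value(x, t):
    return x / t


# Differentiate with respect to t (argnums=1): d(x/t)/dt = -x / t**2,
# evaluated at x = 15.0, t = 1.0.
dv_dt = jax.grad(velocity_value, argnums=1)(15.0, 1.0)
print(dv_dt)
```

Dimensionally, the gradient carries the output unit divided by the unit of the differentiated argument, here (m/s)/s = m/s^2, so a unit-aware grad would need to attach that unit to the bare value computed above.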