adrn opened 20 hours ago
I tried this, thinking it might work, and was surprised that it throws an error. Is this a bug, or something we can't support at the moment?
```python
import jax
import astropy.units as u
from unxt import Quantity
from unxt import experimental

# Physical types used to annotate the function's inputs and output.
length = u.get_physical_type("length")
time = u.get_physical_type("time")
velocity = u.get_physical_type("velocity")


@jax.jit
def test_ad1(x: Quantity[length], t: Quantity[time]) -> Quantity[velocity]:
    return x / t


# Differentiate with respect to t (argnums=1), asking for SI units.
experimental.grad(test_ad1, units="si", argnums=1)(
    Quantity(15.0, u.m), Quantity(1.0, u.s)
)
```
```
...
UnitConversionError: 'm' (length) and 's' (time) are not convertible
```
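For reference, this is the derivative the call is asking for, worked by hand: `argnums=1` selects `t`, so we differentiate $x/t$ with respect to $t$ and evaluate at the inputs used above ($x = 15\,\mathrm{m}$, $t = 1\,\mathrm{s}$):

$$
\frac{\partial}{\partial t}\left(\frac{x}{t}\right) = -\frac{x}{t^{2}},
\qquad
\left.-\frac{x}{t^{2}}\right|_{x = 15\,\mathrm{m},\ t = 1\,\mathrm{s}} = -15\ \mathrm{m\,s^{-2}}
$$

So a working call would presumably return something equivalent to $-15\ \mathrm{m/s^2}$, i.e. units of (output units) / (units of `t`).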
We should figure out how to support this!
The following already works:
```python
experimental.grad(test_ad1, units=(u.m, u.s), argnums=1)(
    Quantity(15.0, u.m), Quantity(1.0, u.s)
)
```
But passing just `"si"` would be simpler.
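One way to think about supporting `"si"` is to expand the unit-system name into one unit per argument, based on each argument's physical dimension, which yields exactly the explicit tuple that `units=(u.m, u.s)` already accepts. The sketch below is hypothetical and is not unxt's API: `UNIT_SYSTEMS` and `resolve_units` are made-up names, units are plain strings, and the dimensions are passed in by hand to keep it self-contained. In real code they would more likely come from the `Quantity[...]` annotations or from `u.get_physical_type` on each argument.

```python
# Hypothetical sketch (NOT unxt's API): expand a unit-system name such as
# "si" into one unit per argument, keyed by each argument's dimension.

# Assumed, illustrative subset of SI units keyed by physical-type name.
UNIT_SYSTEMS: dict[str, dict[str, str]] = {
    "si": {"length": "m", "time": "s", "mass": "kg", "velocity": "m / s"},
}


def resolve_units(units, dimensions):
    """Return a tuple of unit strings, one per argument.

    `units` is either a sequence of explicit units (passed through after a
    length check) or the name of a unit system, which is looked up once per
    argument dimension.
    """
    if isinstance(units, str):
        system = UNIT_SYSTEMS[units]  # KeyError for an unknown system name
        return tuple(system[dim] for dim in dimensions)
    if len(units) != len(dimensions):
        raise ValueError("need exactly one unit per argument")
    return tuple(units)


# Dimensions of test_ad1's arguments (x, t).
print(resolve_units("si", ("length", "time")))  # ('m', 's')
```

The design choice here is that a unit-system string is never applied as a single unit to every argument; it is resolved per argument, so `x` gets `m` and `t` gets `s`.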