Closed nstarman closed 8 months ago
So I think the error here is that you've written quaxify(jax.grad)(func)(...) rather than quaxify(jax.grad(func))(...).
That is, quaxify acts on a function-on-arrays (which jax.grad(func) is), not on a function-on-functions (which jax.grad is).
In particular you should never have to mess around with things like pytype_aval_mappings -- making sure you never have to touch such internals is one of the design goals for Quax.
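To make the function-on-arrays versus function-on-functions distinction concrete, here is a minimal pure-Python sketch. Everything in it is a hypothetical stand-in: toy_grad is a finite-difference derivative rather than jax.grad, and toy_wrap is a simplified wrapper rather than quaxify. The toy wrapper rejects a function argument eagerly, whereas the real libraries may fail later or differently.

```python
from typing import Callable


class Boxed:
    """Toy stand-in for an array-ish value that the wrapper knows how to unpack."""

    def __init__(self, value: float) -> None:
        self.value = value


def toy_grad(f: Callable[[float], float]) -> Callable[[float], float]:
    """Function-on-functions: takes f and returns its central-difference derivative."""

    def df(x: float, h: float = 1e-6) -> float:
        return (f(x + h) - f(x - h)) / (2 * h)

    return df


def toy_wrap(fn: Callable) -> Callable:
    """Wrapper for a function-on-arrays: unpacks Boxed arguments before calling fn."""

    def wrapped(*args):
        for a in args:
            if not isinstance(a, (Boxed, int, float)):
                raise TypeError(
                    f"expected an array-like argument, got {type(a).__name__}"
                )
        return fn(*(a.value if isinstance(a, Boxed) else a for a in args))

    return wrapped


def cube(x: float) -> float:
    return 5 * x**3


# Correct order: differentiate first, then wrap the resulting function-on-arrays.
good = toy_wrap(toy_grad(cube))(Boxed(1.0))

# Wrong order: the wrapper receives `cube` itself where it expects array-like data.
try:
    toy_wrap(toy_grad)(cube)
    wrong_order_error = None
except TypeError as err:
    wrong_order_error = str(err)
```

In the correct order the wrapper only ever sees data, so good is approximately 15. In the wrong order the wrapper is handed a function as if it were data.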
Regarding jpu, this is an interesting idea! I hadn't really looked into the idea of using Quax to handle quantities, but you're totally right, I think this is probably doable.
Thanks! I had originally tried quaxify(jax.grad(func)) but ran into problems and then tried quaxify(jax.grad)(func), which seemed to get further. Returning to quaxify(jax.grad(func)), the problem is an interesting one.
Assuming #6, consider
import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify
jax.config.update("jax_enable_x64", True)
x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")
def func(q: Quantity) -> Quantity:
    return 5 * q**3
out = quaxify(jax.grad(func))(q[0])
out.value, out.unit
> (Array(15., dtype=float64), Unit("m2"))
This works perfectly! 🎉
The problem arises when func constructs a Quantity internally, so the input can't simply be stripped of units beforehand and have the units re-attached afterwards. For example:
def func(q: Quantity) -> Quantity:
    return 5 * q**3 + Quantity(jnp.array(1.0), "m3")
out = quaxify(jax.grad(func))(q[0])
> ValueError: Cannot add a non-quantity and quantity.
This is my error message in
@register(lax.add_p)
def _add_p_vq(x: DenseArrayValue, y: Quantity) -> Quantity:
    # x = 0 is a special case
    if jnp.array_equal(x, 0):
        return y
    # otherwise we can't add a quantity to a normal value
    raise ValueError("Cannot add a non-quantity and quantity.")
What appears to be happening is that 5 * q ** 3 is coming out as a DenseArrayValue, not a Quantity. I'm not sure how this is happening, since if I add jax.debug.print to func I don't see this occurring:
def func(q: Quantity) -> Quantity:
    jax.debug.print("q {}, {}, {}", type(q), type(q.primal), type(q.primal.value))
    jax.debug.print("5 q ** 3: {}", type((5 * q**3).primal.value))
    return 5 * q**3 + Quantity(1.0, unit="m3")
out = quaxify(jax.grad(func))(q[0])
> q <class 'jax._src.interpreters.ad.JVPTracer'>, <class 'quax._core._QuaxTracer'>, <class 'jax_quantity._core.Quantity'>
> 5 q ** 3: <class 'jax_quantity._core.Quantity'>
If 5 q ** 3 is a Quantity then the lax.add that should be used is
@register(lax.add_p)
def _add_p_qq(x: Quantity, y: Quantity) -> Quantity:
    return Quantity(lax.add(x.to_value(x.unit), y.to_value(x.unit)), unit=x.unit)
However, this doesn't appear to be the case.
Hopefully if this can be solved then grad would work! Thanks.
Sorry for the noise.
Looking into the stack trace, I suspect I got the 5 q **3 diagnosis wrong. Jax's binop flip is being called, leading me to believe it's actually the Quantity(jnp.array(1.0), "m3") in 5 * q**3 + Quantity(jnp.array(1.0), "m3") that is being turned into a non-Quantity DenseArrayValue. This is further corroborated by a test where I construct the addition differently and explicitly call lax.add to circumvent the radd binop flip of add.
constant = Quantity(1.0, unit="m2")
@quaxify
def func(q: Quantity) -> Quantity:
    jax.debug.print("5 q ** 3: {}", (5 * q**3).value)
    jax.debug.print("c * q: {}", type((constant * q).value))
    return jax.lax.add(5 * q**3, constant * q)
out = func(q[0])
out.value, out.unit
> 5 q ** 3: Quantity(value=f64[], unit=Unit("m3"))
> c * q: <class 'quax._core._QuaxTracer'>
> UnitConversionError: ... volume and area can't be added # (my words)
So the Quantity multiplication is strange, with the value being a QuaxTracer and the resulting unit being m^2, not m^3. I have https://github.com/nstarman/jax-quantity/blob/5a9d082b51b06c5b0467b18a8175cad183a4e89c/src/jax_quantity/_register_primitives.py#L887-L900, which appears to be correct, but I must be subtly misunderstanding how to override the primitive when jax passes Tracers through the operations.
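For readers unfamiliar with the "binop flip" mentioned above: in plain Python, a + b falls back to the reflected method b.__radd__(a) when a.__add__ returns NotImplemented. The sketch below shows only that language-level mechanism. Plain, Tagged and explicit_add are hypothetical names, not JAX or Quax code, and the error message is borrowed from the thread purely for illustration.

```python
calls = []  # records which operator method Python actually invoked


def explicit_add(x, y):
    """Toy explicit add: only two Tagged values with the same unit may be added."""
    if isinstance(x, Tagged) and isinstance(y, Tagged) and x.unit == y.unit:
        return Tagged(x.value + y.value, x.unit)
    raise TypeError("Cannot add a non-quantity and quantity.")


class Plain:
    """Toy left operand that knows nothing about Tagged values."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other):
        if isinstance(other, Plain):
            return Plain(self.value + other.value)
        return NotImplemented  # lets Python try the reflected method


class Tagged:
    """Toy value carrying a unit string."""

    def __init__(self, value: float, unit: str) -> None:
        self.value = value
        self.unit = unit

    def __add__(self, other):
        calls.append(f"__add__({type(self).__name__}, {type(other).__name__})")
        return explicit_add(self, other)

    def __radd__(self, other):
        calls.append(f"__radd__({type(self).__name__}, {type(other).__name__})")
        return explicit_add(other, self)


# Plain.__add__ returns NotImplemented, so Python calls Tagged.__radd__ instead.
try:
    Plain(1.0) + Tagged(2.0, "m3")
    flip_error = None
except TypeError as err:
    flip_error = str(err)
```

Seeing a reflected method in a stack trace therefore suggests the left operand did not handle the operation itself, which is consistent with the revised diagnosis in the comment above.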
Hmm, it's not clear to me if you have a question or if you've resolved it?
FWIW, I will comment that you are doing something that I'm still not completely happy with Quax's behaviour for, which is returning custom array-ish values from primitive rules. This is fine when working with a function (Quantity, DenseArrayValue) -> Quantity, for which we know the exact types.
But it's not really clear how it should be defined for something like (Quantity, ArrayValue) -> Quantity, where the ArrayValue could be any concrete type at all. Let me explain a little bit.
Problem 1: handling nested values
For example maybe the ArrayValue happens to be a hypothetical SparseArrayValue -- and in this case we may like the implementation to redispatch (using Quax again!) to a (DenseArrayValue, SparseArrayValue) -> SparseArrayValue operation, and then wrap the return back into a Quantity. So that the overall return value is Quantity(SparseArrayValue(...), ...).
That's now pretty weird: when you wrote Quantity you appear to have assumed that it's wrapping specifically and only a raw JAX array: not a SparseArrayValue. This could cause all kinds of bugs later down the line.
Problem 2: easy bugs in implementations
As another example of the problems this causes, consider this line:
in which you've written lax.ne(x.value, y) -- but y is an ArrayValue, for which this operation may raise an error. You probably meant quaxify(lax.ne)(x.value, y). That's a bit annoying, having to put a quaxify on literally every JAX operation inside a Quax rule.
Right now I'm still thinking about a plan to fix this. I'm not sure exactly what that will be yet, though! So I'd welcome any thoughts on what might be a nice way to do this. The end goal will probably be to end up with something Julia-like, where you can just nest array-ish values freely without worrying too much; possibly this can be accomplished by always putting Quax at the bottom of the interpreter stack. Details very much TBD; you have been warned :)
Hmm, it's not clear to me if you have a question or if you've resolved it?
you're right either way 😆. I've half-solved the problem. Functions that don't internally construct and operate on a Quantity work great. Functions that do still raise an error.
FWIW, I will comment that you are doing something that I'm still not completely happy with Quax's behaviour for, which is returning custom array-ish values from primitive rules. This is fine when working with a function (Quantity, DenseArrayValue) -> Quantity, for which we know the exact types.
Thanks! I'll change my annotations to DenseArrayValue for now. I would like to be able to work on custom array-ish values, but Jax arrays and Quantity are good for now.
@patrick-kidger, I found the root of the problem and have a question about how to best resolve the issue.
Consider this MWE (a simplified form of Quantity)
from typing import Self
import jax
import quax
from jax import lax
class MyArray(quax.ArrayValue):
    value: jax.Array
    unit: str

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the array."""
        return self.value.shape

    def materialise(self) -> None:
        raise RuntimeError("Refusing to materialise `MyArray`.")

    def aval(self) -> jax.core.ShapedArray:
        return jax.core.get_aval(self.value)


@quax.register(lax.mul_p)
def _(x: MyArray, y: MyArray) -> MyArray:
    unit = f"{x.unit}*{y.unit}"
    return MyArray(lax.mul(x.value, y.value), unit=unit)
If I define a function func then everything works as expected
def func(q: MyArray) -> MyArray:
    c = MyArray(2.0, unit="m2")
    jax.debug.print("mutiplying {} * {}", type(c), type(q))
    return c * q
x = MyArray(1.5, unit="m")
out = func(x)
> mutiplying <class '__main__.MyArray'> * <class '__main__.MyArray'>
> MyArray(value=f64[], unit='m2*m')
However if this function is quaxify'd then it will fail
quaxfunc = quax.quaxify(func)
out = quaxfunc(x)
> mutiplying <class '__main__.MyArray'> * <class 'quax._core._QuaxTracer'>
The issue is that x in quaxfunc is wrapped by the QuaxTracer (x.array.value is MyArray) but c remains a MyArray.
I think this is the issue underlying quaxify(grad(func)).
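The mismatch described above can be reproduced in miniature without JAX or Quax. In this sketch, ToyTracer and toy_transform are hypothetical stand-ins for a tracer type and a function transform. They model only the observation in this thread (arguments arrive wrapped, while values constructed inside the function do not), not Quax's actual internals.

```python
class ToyTracer:
    """Toy stand-in for a tracer: wraps an argument while a transform is active."""

    def __init__(self, value) -> None:
        self.value = value


class ToyArray:
    """Toy stand-in for a custom array-ish value with a unit."""

    def __init__(self, value: float, unit: str) -> None:
        self.value = value
        self.unit = unit


seen = []  # (type of c, type of q) observed on each call


def func(q):
    c = ToyArray(2.0, unit="m2")  # constructed inside, so never wrapped
    seen.append((type(c).__name__, type(q).__name__))
    return c, q


def toy_transform(fn):
    """Toy transform: wraps every argument in a ToyTracer before calling fn."""

    def wrapped(*args):
        return fn(*(ToyTracer(a) for a in args))

    return wrapped


x = ToyArray(1.5, unit="m")
func(x)                 # plain call: both operands are ToyArray
toy_transform(func)(x)  # transformed call: q is wrapped, c is not
```

After both calls, seen holds one (ToyArray, ToyArray) pair and one (ToyArray, ToyTracer) pair, mirroring the two debug prints above. Any rule keyed on both operands being the custom type would match only the first.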
I can hack around this by defining
@quax.register(lax.mul_p)
def _(x: MyArray, y: quax.DenseArrayValue) -> MyArray:
    assert isinstance(y.array, quax._core._QuaxTracer)
    assert isinstance(y.array.value, MyArray)
    unit = f"{x.unit}*{y.array.value.unit}"
    return MyArray(lax.mul(x.value, y.array.value.value), unit=unit)
I don't see a comparable scenario in quax atm. Thanks for the help!
This looks expected. When you wrap func into quaxify(func), then you are turning a function that acts on arrays into a function that acts on quax.ArrayValues. To better understand this, try printing out the jaxpr for the original function:
jaxpr = jax.make_jaxpr(func)(x)
print(jaxpr)
This defines a sequence of primitive operations, turning input into output. Then the Quaxified version -- quaxify(func) -- will iterate through every equation in the jaxpr, looking for a registered rule for that primitive, for the ArrayValues used as inputs.
So in this case, the quaxify means that the input to func is being thought of as a regular array rather than as a MyArray, and that's why you hit the (:MyArray, :DenseArrayValue) rule.
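The "iterate through every equation, looking for a registered rule" description can be sketched as a tiny interpreter. This is a pure-Python toy under stated assumptions: Eqn, RULES, DEFAULTS, register and interpret are hypothetical names, the program is written by hand rather than produced by jax.make_jaxpr, and none of this is Quax's actual implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    """One toy equation: apply `primitive` to named inputs, bind to `outvar`."""

    primitive: str
    invars: tuple
    outvar: str


@dataclass(frozen=True)
class Tagged:
    """Toy array-ish value carrying a unit string."""

    value: float
    unit: str


RULES = {}  # (primitive name, argument types) -> rule
DEFAULTS = {"mul": lambda x, y: x * y}  # plain behaviour when no rule matches


def register(primitive, *types):
    """Register a rule for a primitive and an exact tuple of argument types."""

    def decorator(fn):
        RULES[(primitive, types)] = fn
        return fn

    return decorator


@register("mul", Tagged, Tagged)
def _mul_tagged_tagged(x, y):
    return Tagged(x.value * y.value, f"{x.unit}*{y.unit}")


@register("mul", float, Tagged)
def _mul_float_tagged(x, y):
    return Tagged(x * y.value, y.unit)


def interpret(eqns, env, out):
    """Walk the equations in order, dispatching on the runtime types of the inputs."""
    env = dict(env)
    for eqn in eqns:
        args = [env[name] for name in eqn.invars]
        key = (eqn.primitive, tuple(type(a) for a in args))
        rule = RULES.get(key, DEFAULTS[eqn.primitive])
        env[eqn.outvar] = rule(*args)
    return env[out]


# Hand-written program for: out = k * (c * q)
program = [Eqn("mul", ("c", "q"), "t"), Eqn("mul", ("k", "t"), "out")]

tagged_result = interpret(
    program, {"c": Tagged(2.0, "m2"), "q": Tagged(1.5, "m"), "k": 3.0}, "out"
)
plain_result = interpret(program, {"c": 2.0, "q": 1.5, "k": 3.0}, "out")
```

Which rule fires depends entirely on the types the interpreter sees at each equation. If an input is presented as a regular array rather than as the custom type, a different rule (or the default) is selected, which is the situation described above.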
Here, what you probably want to do is just not apply the quaxify wrapper, as what you have already works without it.
(FWIW this is a complexity I'm hoping will be tidied up when I rewrite things a bit, as per my previous comment.)
Here, what you probably want to do is just not apply the quaxify wrapper, as what you have already works without it
Unfortunately, since that function was a MWE of quaxify(grad(func)), applying the quaxify wrapper is a necessity.
So in this case, the quaxify means that the input to func is being thought of as a regular array rather than as a MyArray, and that's why you hit the (:MyArray, :DenseArrayValue) rule.
Thanks for the diagnosis! What is the proper way to get the MyArray from inside the DenseArrayValue? I tried x.array.value (requiring adding if/else statement logic to the overrides) and this works for the quaxify decorator, but grad then changes x.array to a JVPTracer.
My hacking around is in https://github.com/GalacticDynamics/jax-quantity/pull/4.
import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify
jax.config.update("jax_enable_x64", True)
x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")
def func(q: Quantity) -> Quantity:
    return Quantity(2.0, unit="m2") * q
out = func(q[0])
print(out.value, out.unit)
> (Array(2., dtype=float64), Unit("m3"))
out = quaxify(jax.grad(func))(q[2])
print(out.value, out.unit)
> TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m3")).
I think the way I deconstructed JVPTracer is leading to this problem. The multiplication is coming in as mul(:Quantity, :DenseArrayValue). I pass this to mul(:Quantity, :Quantity) by taking y.array.primal.value on the latter argument.
Okay, returning to this issue! It has not been forgotten about...
As above, this discussion has prompted me to do some hard thinking on the design choices for Quax. The just-released v0.0.3 should hopefully straighten things out a bit. It's very much a breaking release (!), but if it seems to work then hopefully we can standardise on this, and start building libraries on top of Quax in earnest.
With respect to topics discussed in this issue:
DenseArrayValue has been removed. We just interact directly with normal arrays now. Hopefully this simplification helps avoid issues there.
Quantity(2, unit="m") * x -- all Quax types must be passed across a quax.quaxify boundary first. On the one hand this simplifies reasoning about Quax a lot, removes some bad footguns, and hopefully avoids some of the confusion you have here. On the other hand, I can see that this might be a bit fiddly when you need to pipe e.g. gravitational constants through as arguments (see the g = Unitful(jnp.array(9.81), {meters: 1, seconds: -2}) in the above example), rather than just creating them on-the-fly. Unfortunately I couldn't see any way to have a principled design that made this work. I'm definitely open to suggestions on this front.
Thanks! grad works great when all inputs are format arguments.
In https://github.com/GalacticDynamics/jax-quantity/pull/4 I'm trying to get jax.grad to work on functions that accept Quantity arguments, and have run into some difficulties.
The following doesn't work, returning an error
TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m2")).
This error was expected since grad checks for scalar outputs (with jax._src.api._check_scalar). The underlying issue appeared to be that _check_scalar calls concrete_aval, which errors on Quantity. quax-compatible classes have an aval() method so I hooked that up to a handler and registered it into pytype_aval_mappings
While this gets a few lines further in grad, unfortunately this causes a disagreement between pytree structures with the error TypeError: Tree structure of cotangent input PyTreeDef(*), does not match structure of primal output PyTreeDef(CustomNode(Quantity[('value',), ('unit',), (Unit("m2"),)], [*])). I haven't figured out how to fix this issue. Any suggestions would be appreciated!
p.s. @dfm has figured out how to do grad on Quantity in jpu by shunting the units to aux data and re-assembling after. This solution works well, but it's a solution unique to Quantity, requiring a custom grad function. I was hoping to get this working with quaxify in a way that didn't require in https://github.com/GalacticDynamics/array-api-jax-compat dispatching using plum to library-specific grad implementations (especially since it's not obvious on what to dispatch to map func to the Quantity implementation).
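As a rough illustration of the "shunt the units to aux data and re-assemble after" idea, here is a pure-Python sketch. It is not jpu's code. Length, toy_grad_with_aux and unit_aware_grad are hypothetical names, the derivative is a finite difference standing in for a real autodiff call with an auxiliary output, and units are reduced to an integer power of metres to keep the example short.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Length:
    """Toy quantity: a value together with an integer power of metres."""

    value: float
    power: int  # e.g. 3 means m^3

    def __pow__(self, n: int) -> "Length":
        return Length(self.value**n, self.power * n)

    def __rmul__(self, k: float) -> "Length":
        return Length(k * self.value, self.power)


def toy_grad_with_aux(f, x: float, h: float = 1e-6):
    """Differentiate the first output of f at x and pass the second (aux) through."""
    up, aux = f(x + h)
    down, _ = f(x - h)
    return (up - down) / (2 * h), aux


def unit_aware_grad(func):
    """Strip units on the way in, carry the output unit as aux, re-attach after."""

    def wrapped(q: Length) -> Length:
        def on_values(v: float):
            out = func(Length(v, q.power))
            return out.value, out.power  # the unit travels as auxiliary data

        dvalue, out_power = toy_grad_with_aux(on_values, q.value)
        # d(out)/d(in) carries the output unit divided by the input unit.
        return Length(dvalue, out_power - q.power)

    return wrapped


out = unit_aware_grad(lambda q: 5 * q**3)(Length(1.0, 1))
```

For 5 * q**3 evaluated at 1 m this gives a value of about 15 with power 2, matching the m2 result reported earlier in the thread. The cost, as noted above, is a custom grad wrapper specific to the quantity type.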