patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

`quaxify` on a `jax.grad` #5

Closed nstarman closed 8 months ago

nstarman commented 10 months ago

In https://github.com/GalacticDynamics/jax-quantity/pull/4 I'm trying to get jax.grad to work on functions that accept Quantity arguments, and have run into some difficulties.

The following doesn't work:

import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q) -> Quantity:
    return 5 * q**2 + Quantity(1.0, unit="m2")

quaxify(jax.grad)(func)(q[0])

This raises TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m2")). The error is expected, since grad checks for scalar outputs (with jax._src.api._check_scalar). The underlying issue appears to be that _check_scalar calls concrete_aval, which errors on Quantity. Quax-compatible classes have an aval() method, so I hooked that up to a handler and registered it in pytype_aval_mappings:

jax._src.core.pytype_aval_mappings[Quantity] = lambda q: q.aval()

quaxify(jax.grad)(func)(q[0])

While this gets a few lines further in grad, it then hits a mismatch between pytree structures: TypeError: Tree structure of cotangent input PyTreeDef(*), does not match structure of primal output PyTreeDef(CustomNode(Quantity[('value',), ('unit',), (Unit("m2"),)], [*])). I haven't figured out how to fix this issue. Any suggestions would be appreciated!
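For reference, both symptoms can be reproduced in plain JAX without Quax. The sketch below uses a hypothetical Tagged pytree (a stand-in invented for illustration, not the jax-quantity Quantity class) whose unit is stored as static aux data. It only shows that jax.grad rejects an output that is not a bare scalar array, and that such an output has a different tree structure from a scalar leaf.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Tagged:
    """Hypothetical stand-in for a unit-carrying array (not jax_quantity.Quantity)."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        # The array is the only child; the unit is static aux data.
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


def func(x):
    # Returns a pytree node rather than a bare scalar array.
    return Tagged(5.0 * x * x, "m2")


# jax.grad checks that the output is a scalar array and raises otherwise.
try:
    jax.grad(func)(jnp.float32(1.0))
    grad_error = None
except TypeError as err:
    grad_error = err
print(grad_error)

# A bare scalar is a single leaf; Tagged is a custom node wrapping one leaf.
leaf_structure = tree_util.tree_structure(jnp.float32(1.0))
node_structure = tree_util.tree_structure(Tagged(jnp.float32(1.0), "m2"))
print(leaf_structure)
print(node_structure)
```

The two printed structures differ in the same way as the PyTreeDef(*) and PyTreeDef(CustomNode(...)) in the error message above.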

P.S. @dfm has figured out how to do grad on Quantity in jpu by shunting the units to aux data and re-assembling them afterwards. This works well, but it is specific to Quantity and requires a custom grad function. I was hoping to get this working with quaxify so that https://github.com/GalacticDynamics/array-api-jax-compat doesn't have to dispatch (using plum) to library-specific grad implementations, especially since it's not obvious what to dispatch on to map func to the Quantity implementation.

patrick-kidger commented 10 months ago

So I think the error here is that you've written quaxify(jax.grad)(func)(...) rather than quaxify(jax.grad(func))(...).

That is, quaxify acts on a function-on-arrays (which jax.grad(func) is), not on a function-on-functions (which jax.grad is).
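The same ordering rule can be seen with jax.jit, which also wraps functions on arrays. This is only an analogy in plain JAX, with jax.jit standing in for quaxify; it says nothing about how quaxify behaves internally when given jax.grad.

```python
import jax


def func(x):
    return x**3


# jax.grad(func) maps an array to an array, so an array-level transform
# (jax.jit here, standing in for quaxify) can wrap it directly.
grad_value = jax.jit(jax.grad(func))(2.0)
print(grad_value)

# jax.grad itself maps a function to a function. Wrapping it in jax.jit
# means `func` is passed as an argument to a jitted function, and a Python
# function is not a valid JAX array type.
try:
    jax.jit(jax.grad)(func)(2.0)
    wrong_order_error = None
except TypeError as err:
    wrong_order_error = err
print(type(wrong_order_error).__name__)
```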

In particular you should never have to mess around with things like pytype_aval_mappings -- making sure you never have to touch such internals is one of the design goals for Quax.

Regarding jpu, this is an interesting idea! I hadn't really looked into the idea of using Quax to handle quantities, but you're totally right, I think this is probably doable.

nstarman commented 10 months ago

Thanks! I had originally tried quaxify(jax.grad(func)) but ran into problems and then tried quaxify(jax.grad)(func), which seemed to get further. Returning to quaxify(jax.grad(func)), the problem is an interesting one.

Assuming #6, consider

import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q: Quantity) -> Quantity:
    return 5 * q**3

out = quaxify(jax.grad(func))(q[0])
out.value, out.unit
> (Array(15., dtype=float64), Unit("m2"))

This works perfectly! 🎉

The problem arises when func constructs a Quantity internally, so the input can't simply be stripped of its units beforehand and have them re-attached afterwards. For example:

def func(q: Quantity) -> Quantity:
    return 5 * q**3 + Quantity(jnp.array(1.0), "m3")

out = quaxify(jax.grad(func))(q[0])
> ValueError: Cannot add a non-quantity and quantity.

This is my own error message, raised in:

@register(lax.add_p)
def _add_p_vq(x: DenseArrayValue, y: Quantity) -> Quantity:
    # x = 0 is a special case
    if jnp.array_equal(x, 0):
        return y
    # otherwise we can't add a quantity to a normal value
    raise ValueError("Cannot add a non-quantity and quantity.")

What appears to be happening is that 5 * q**3 is coming out as a DenseArrayValue, not a Quantity. I'm not sure how this is happening, since if I add jax.debug.print to func I don't see this occurring:

def func(q: Quantity) -> Quantity:
    jax.debug.print("q {}, {}, {}", type(q), type(q.primal), type(q.primal.value))
    jax.debug.print("5 q ** 3: {}", type((5 * q**3).primal.value))
    return 5 * q**3 + Quantity(1.0, unit="m3")

out = quaxify(jax.grad(func))(q[0])
> q <class 'jax._src.interpreters.ad.JVPTracer'>, <class 'quax._core._QuaxTracer'>, <class 'jax_quantity._core.Quantity'>
> 5 q ** 3: <class 'jax_quantity._core.Quantity'>

If 5 q ** 3 is a Quantity then the lax.add that should be used is

@register(lax.add_p)
def _add_p_qq(x: Quantity, y: Quantity) -> Quantity:
    return Quantity(lax.add(x.to_value(x.unit), y.to_value(x.unit)), unit=x.unit)

However, this doesn't appear to be the case. Hopefully if this can be solved then grad would work! Thanks.

nstarman commented 10 months ago

Sorry for the noise. Looking into the stack trace, I suspect I got the 5 * q**3 diagnosis wrong. JAX's binop flip is being called, leading me to believe it's actually the Quantity(jnp.array(1.0), "m3") in 5 * q**3 + Quantity(jnp.array(1.0), "m3") that is being turned into a non-Quantity DenseArrayValue. This is further corroborated by a test where I construct the addition differently and explicitly call lax.add to circumvent the radd binop flip of add:

constant = Quantity(1.0, unit="m2")

@quaxify
def func(q: Quantity) -> Quantity:
    jax.debug.print("5 q ** 3: {}", (5 * q**3).value)
    jax.debug.print("c * q: {}", type((constant * q).value))
    return jax.lax.add(5 * q**3, constant * q)

out = func(q[0])
out.value, out.unit
> 5 q ** 3: Quantity(value=f64[], unit=Unit("m3"))
> c * q: <class 'quax._core._QuaxTracer'>
> UnitConversionError: ... volume and area can't be added  # (my words)

So the Quantity multiplication is strange, with the value being a QuaxTracer and the resulting unit being m^2, not m^3. I have https://github.com/nstarman/jax-quantity/blob/5a9d082b51b06c5b0467b18a8175cad183a4e89c/src/jax_quantity/_register_primitives.py#L887-L900, which appears to be correct, but I must subtly misunderstand how to override the primitive when JAX passes Tracers through the operations.

patrick-kidger commented 10 months ago

Hmm, it's not clear to me if you have a question or if you've resolved it?


FWIW, I will comment that you are doing something that I'm still not completely happy with Quax's behaviour for, which is returning custom array-ish values from primitive rules. This is fine when working with a function (Quantity, DenseArrayValue) -> Quantity, for which we know the exact types.

But it's not really clear how it should be defined for something like (Quantity, ArrayValue) -> Quantity, where the ArrayValue could be any concrete type at all. Let me explain a little bit.

Problem 1: handling nested values

For example maybe the ArrayValue happens to be a hypothetical SparseArrayValue -- and in this case we may like the implementation to redispatch (using Quax again!) to a (DenseArrayValue, SparseArrayValue) -> SparseArrayValue operation, and then wrap the return back into a Quantity. So that the overall return value is Quantity(SparseArrayValue(...), ...).

That's now pretty weird: when you wrote Quantity you appear to have assumed that it's wrapping specifically and only a raw JAX array:

https://github.com/nstarman/jax-quantity/blob/5a9d082b51b06c5b0467b18a8175cad183a4e89c/src/jax_quantity/_core.py#L26

not a SparseArrayValue. This could cause all kinds of bugs later down the line.

Problem 2: easy bugs in implementations

As another example of the problems this causes, consider this line:

https://github.com/nstarman/jax-quantity/blob/5a9d082b51b06c5b0467b18a8175cad183a4e89c/src/jax_quantity/_register_primitives.py#L920

in which you've written lax.ne(x.value, y) -- but y is an ArrayValue, for which this operation may raise an error. You probably meant quaxify(lax.ne)(x.value, y). That's a bit annoying, having to put a quaxify on literally every JAX operation inside a Quax rule.


Right now I'm still thinking about a plan to fix this. I'm not sure exactly what that will be yet, though! So I'd welcome any thoughts on what might be a nice way to do this. The end goal will probably be to end up with something Julia-like, where you can just nest array-ish values freely without worrying too much; possibly this can be accomplished by always putting Quax at the bottom of the interpreter stack. Details very much TBD; you have been warned :)

nstarman commented 10 months ago

> Hmm, it's not clear to me if you have a question or if you've resolved it?

You're right either way 😆. I've half-solved the problem. Functions that don't internally construct and operate on a Quantity work great. Functions that do still raise an error.

> FWIW, I will comment that you are doing something that I'm still not completely happy with Quax's behaviour for, which is returning custom array-ish values from primitive rules. This is fine when working with a function (Quantity, DenseArrayValue) -> Quantity, for which we know the exact types.

Thanks! I'll change my annotations to DenseArrayValue for now. I would like to be able to work on custom array-ish values, but JAX arrays and Quantity are good for now.

nstarman commented 10 months ago

@patrick-kidger, I found the root of the problem and have a question about how to best resolve the issue.

Consider this MWE (a simplified form of Quantity):

from typing import Self

import jax
import quax
from jax import lax

class MyArray(quax.ArrayValue):
    value: jax.Array
    unit: str

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the array."""
        return self.value.shape

    def materialise(self) -> None:
        raise RuntimeError("Refusing to materialise `MyArray`.")

    def aval(self) -> jax.core.ShapedArray:
        return jax.core.get_aval(self.value)

@quax.register(lax.mul_p)
def _(x: MyArray, y: MyArray) -> MyArray:
    unit = f"{x.unit}*{y.unit}"
    return MyArray(lax.mul(x.value, y.value), unit=unit)

If I define a function func then everything works as expected:

def func(q: MyArray) -> MyArray:
    c = MyArray(2.0, unit="m2")
    jax.debug.print("multiplying {} * {}", type(c), type(q))
    return c * q

x = MyArray(1.5, unit="m")

out = func(x)
> multiplying <class '__main__.MyArray'> * <class '__main__.MyArray'>
> MyArray(value=f64[], unit='m2*m')

However, if this function is quaxify'd then it fails:

quaxfunc = quax.quaxify(func)
out = quaxfunc(x)
> multiplying <class '__main__.MyArray'> * <class 'quax._core._QuaxTracer'>

The issue is that x in quaxfunc is wrapped by the QuaxTracer (x.array.value is MyArray) but c remains a MyArray.
I think this is the issue underlying quaxify(grad(func)).
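A similar asymmetry shows up in plain JAX, which may help picture what is going on: while a transform traces a function, the argument is replaced by a tracer, but a value constructed inside the function is not. The sketch below uses jax.grad and an ordinary array constant; it is an analogy only and does not involve Quax or MyArray.

```python
import jax
import jax.numpy as jnp

seen = {}


def func(x):
    c = jnp.asarray(2.0)  # created inside the function, not passed in
    # Record what each value looks like while jax.grad is tracing func.
    seen["argument_is_tracer"] = isinstance(x, jax.core.Tracer)
    seen["constant_is_tracer"] = isinstance(c, jax.core.Tracer)
    return c * x


gradient = jax.grad(func)(1.5)
print(seen)
print(gradient)
```

Here the argument is a tracer and the internally built constant is not, which mirrors x being wrapped while c stays a MyArray.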

I can hack around this by defining

@quax.register(lax.mul_p)
def _(x: MyArray, y: quax.DenseArrayValue) -> MyArray:
    assert isinstance(y.array, quax._core._QuaxTracer)
    assert isinstance(y.array.value, MyArray)

    unit = f"{x.unit}*{y.array.value.unit}"
    return MyArray(lax.mul(x.value, y.array.value.value), unit=unit)

I don't see a comparable scenario in quax atm. Thanks for the help!

patrick-kidger commented 9 months ago

This looks expected. When you wrap func into quaxify(func), then you are turning a function that acts on arrays into a function that acts on quax.ArrayValues. To better understand this, try printing out the jaxpr for the original function:

jaxpr = jax.make_jaxpr(func)(x)
print(jaxpr)

This defines a sequence of primitive operations that turn the input into the output. The quaxified version, quaxify(func), then iterates through every equation in the jaxpr, looking up the registered rule for that primitive and for the ArrayValues used as its inputs.
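To make the "sequence of primitive operations" concrete, here is a small plain-JAX sketch that lists the primitive of each equation in a jaxpr. The function is an arbitrary example chosen for illustration, and the loop only inspects the jaxpr; it does not reproduce Quax's rule lookup.

```python
import jax
from jax import lax


def func(x):
    return lax.add(lax.mul(x, x), lax.sin(x))


closed_jaxpr = jax.make_jaxpr(func)(1.5)
print(closed_jaxpr)

# Each equation records the primitive it applies. An interpreter that
# dispatches on primitives would look up one rule per equation.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```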

So in this case, the quaxify means that the input to func is being thought of as a regular array rather than as a MyArray, and that's why you hit the (:MyArray, :DenseArrayValue) rule.

Here, what you probably want to do is just not apply the quaxify wrapper, as what you have already works without it.

(FWIW this is a complexity I'm hoping will be tidied up when I rewrite things a bit, as per my previous comment.)

nstarman commented 9 months ago

> Here, what you probably want to do is just not apply the quaxify wrapper, as what you have already works without it

Unfortunately, since that function was a MWE of quaxify(grad(func)), applying the quaxify wrapper is a necessity.

> So in this case, the quaxify means that the input to func is being thought of as a regular array rather than as a MyArray, and that's why you hit the (:MyArray, :DenseArrayValue) rule.

Thanks for the diagnosis! What is the proper way to get the MyArray from inside the DenseArrayValue? I tried x.array.value (which requires adding if/else logic to the overrides) and this works for the quaxify decorator, but grad then changes x.array to a JVPTracer. My hacking around is in https://github.com/GalacticDynamics/jax-quantity/pull/4.

import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q: Quantity) -> Quantity:
    return Quantity(2.0, unit="m2") * q

out = func(q[0])
print(out.value, out.unit)
> (Array(2., dtype=float64), Unit("m3"))

out = quaxify(jax.grad(func))(q[2])
print(out.value, out.unit)
> TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m3")).

I think the way I deconstructed JVPTracer is leading to this problem. The multiplication is coming in as mul(:Quantity, :DenseArrayValue). I pass this to mul(:Quantity, :Quantity) by taking y.array.primal.value on the latter argument.

patrick-kidger commented 9 months ago

Okay, returning to this issue! It has not been forgotten about...

As above, this discussion has prompted me to do some hard thinking on the design choices for Quax. The just-released v0.0.3 should hopefully straighten things out a bit. It's very much a breaking release (!), but if it seems to work then hopefully we can standardise on this, and start building libraries on top of Quax in earnest.

With respect to topics discussed in this issue:

nstarman commented 8 months ago

Thanks! grad works great when all inputs are formal arguments.