patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

Support for `zeros_like` and related #2

Closed: nstarman closed this issue 7 months ago.

nstarman commented 9 months ago

Hi! I'm running into a few issues with quax + `jnp.zeros_like`. The built-in quax.zeros.Zeros appears to leak into `jnp.zeros_like` and related functions:

>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # needed for float64 arrays
>>> import jax.numpy as jnp
>>> from quax import quaxify
>>> x = jnp.array([1, 2, 3], dtype=jnp.float64)
>>> quaxify(jnp.zeros_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))
>>> quaxify(jnp.empty_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))

I think `quaxify(jnp.zeros_like)` applied to a plain `jax.Array` should return a `jax.Array`.

More importantly (to me), when I make a custom `quax.ArrayValue` subclass, its rules don't override the `Zero` behavior for these functions:

>>> import quax
>>> class MyArray(quax.ArrayValue):
...     ...  # [override all related primitives]
...
>>> y = MyArray(x)
>>> quaxify(jnp.zeros_like)(y)
Zero(_shape=(3,), _dtype=dtype('float64'))

patrick-kidger commented 9 months ago

So this is intentional but definitely questionable.

The reason for this is a Quax dispatch rule for Zero on the broadcast primitive, which matches any ArrayValue passed as the fill value for the array. This is unusual in that the rule does not need an instance of the corresponding array-ish type (a Zero) as one of its inputs.

I think changing this may be impossible without also making `jnp.zeros(...)` return a plain `jax.Array`. Both use identical primitive binds inside JAX: loosely speaking, they're implemented as `def zeros(shape, dtype): return broadcast(0, shape, dtype)` and `def zeros_like(array): return broadcast(0, array.shape, array.dtype)`.

I'm not completely sure how this might be tackled; I'd welcome any thoughts.
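The point that `zeros` and `zeros_like` bind the same primitive can be checked by tracing both with `jax.make_jaxpr` and listing the primitives each one binds. This is a minimal sketch using plain JAX (no quax); `primitive_names` is a hypothetical helper introduced here, and the exact jaxpr contents (for example, whether an extra `convert_element_type` appears) may vary between JAX versions.

```python
import jax
import jax.numpy as jnp


def primitive_names(fn, *args):
    """Hypothetical helper: names of the primitives bound when tracing fn(*args)."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.array([1.0, 2.0, 3.0])

# jnp.zeros takes no array input: the fill value 0 is a constant.
zeros_prims = primitive_names(lambda: jnp.zeros((3,), jnp.float32))

# jnp.zeros_like only reads x's shape and dtype, then fills with a constant 0.
zeros_like_prims = primitive_names(jnp.zeros_like, x)

print(zeros_prims)
print(zeros_like_prims)
```

Both lists should contain `broadcast_in_dim`. In each case the value being broadcast is a constant created inside the function rather than the input array, so a rule keyed on the broadcast primitive sees the same kind of operand for both calls.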

nstarman commented 9 months ago

Thanks for the fast response.

> The reason for this is this rule for Zero, which will bind against any ArrayValue (which is the fill value for the array). This is unusual in that we don't need an instance of the corresponding array-ish Zero as an input to the rule.

I had registered a specific dispatch for the primitive lax.broadcast_in_dim_p. However, I don't think this is ever called with MyArray, as the Zero is instantiated, then fed into the broadcast.

> I'm not completely sure how this might be tackled; I'd welcome any thoughts.

Short of monkey-patching JAX, I don't know either. I think I might just use plum on a custom zeros_like function that dispatches to MyArray (and Zero). That should work, if a little less elegantly than quax's internal use of plum.
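The workaround described in that comment, a custom `zeros_like` that dispatches on the argument's type, can be sketched as follows. The comment proposes plum; to keep the sketch dependency-light it uses the standard library's `functools.singledispatch` instead, and `MyArray` is a hypothetical plain wrapper standing in for a real `quax.ArrayValue` subclass. Neither is quax or plum API.

```python
from dataclasses import dataclass
from functools import singledispatch

import jax
import jax.numpy as jnp


@dataclass
class MyArray:
    """Hypothetical stand-in for a custom quax.ArrayValue subclass."""

    value: jax.Array


@singledispatch
def zeros_like(x):
    # Fallback: plain arrays (and anything jnp accepts) get a plain JAX array.
    return jnp.zeros_like(x)


@zeros_like.register
def _(x: MyArray):
    # Custom arrays get zeros of the same shape and dtype, re-wrapped in MyArray.
    return MyArray(jnp.zeros_like(x.value))


x = jnp.array([1.0, 2.0, 3.0])
plain = zeros_like(x)  # a jax.Array
custom = zeros_like(MyArray(x))  # a MyArray wrapping a jax.Array of zeros
```

Because the choice is made at the Python call site rather than inside `quaxify`, a custom type never reaches the `Zero` broadcast rule. The cost, as noted above, is a separate entry point instead of plain `jnp.zeros_like`.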

patrick-kidger commented 9 months ago

> as the Zero is instantiated, then fed into the broadcast.

Actually, it's a JAX array that's instantiated! But indeed, it's not a custom MyArray either way.

FWIW, I'm contemplating simply removing that broadcast dispatch rule from Quax. It's always been a bit magic that we produce a Zero without having a Zero input. It's also not totally reliable: something like `quax.quaxify(jnp.zeros_like)(jnp.array(0))` doesn't trigger it, as that doesn't involve a broadcast. With the rule removed, `zeros`/`zeros_like` would unconditionally produce a normal JAX array.

I'm not sure how much that helps you of course, but it's something.

nstarman commented 9 months ago

Still very much a work in progress, but check out https://github.com/GalacticDynamics/array-api-jax-compat/, where I'm leveraging quax to make a bridge to the Array API that also works with JAX array-ish objects.

patrick-kidger commented 9 months ago

Oh, this looks neat! Can you tell me a bit more about this?

I'm noticing a similarity to jax.experimental.array_api. I assume you're building off of that as well?

nstarman commented 9 months ago

Thanks! I actually wasn't aware of jax.experimental.array_api, but that will make my life significantly easier. Each function in https://github.com/GalacticDynamics/array-api-jax-compat/ will then be a quax-ified / plum.dispatcher wrapper around jax.experimental.array_api, plus miscellaneous quax-ified functions like jacfwd and grad.

The goal is to support Astropy-like Quantities in JAX (https://github.com/GalacticDynamics/jax-quantity). In that repo I've completed most of the Astropy -> quax -> JAX bridges.

nstarman commented 8 months ago

array-api-jax-compat is now mostly a quax wrapper around jax.experimental.array_api.

patrick-kidger commented 8 months ago

Awesome!

nstarman commented 7 months ago

I think this is now resolved!