nstarman closed this 9 months ago.
So this is intentional but definitely questionable.
The reason for this is this rule for Zero, which will bind against any ArrayValue (which is the fill value for the array). This is unusual in that we don't need an instance of the corresponding array-ish Zero as an input to the rule.
I think changing this may be impossible without also making jnp.zeros(...) return an Array. Both use identical primitive binds inside JAX. Loosely speaking, they're implemented as

def zeros(shape, dtype): return broadcast(0, shape, dtype)
def zeros_like(array): return broadcast(0, array.shape, array.dtype)
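One way to check the "identical primitive binds" claim is to trace both constructors with jax.make_jaxpr and list the primitives in the resulting jaxpr. This is a minimal sketch, assuming a recent JAX release; the exact jaxpr differs between versions, so it only checks that broadcast_in_dim shows up in both. The helper primitive_names is a hypothetical name introduced for illustration.

```python
import jax
import jax.numpy as jnp


def primitive_names(fn, *args):
    """Trace `fn` with jax.make_jaxpr and return the names of the primitives it binds."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.ones((2, 3), dtype=jnp.float32)

# jnp.zeros takes a shape; jnp.zeros_like takes an existing array.
zeros_prims = primitive_names(lambda: jnp.zeros((2, 3), dtype=jnp.float32))
zeros_like_prims = primitive_names(jnp.zeros_like, x)

# Both lists are expected to contain "broadcast_in_dim".
print(zeros_prims)
print(zeros_like_prims)
```

Because both paths end in the same primitive with a plain scalar fill value, a rule keyed on that primitive cannot tell the two calls apart.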
I'm not completely sure how this might be tackled; I'd welcome any thoughts.
Thanks for the fast response.
> The reason for this is this rule for Zero, which will bind against any ArrayValue (which is the fill value for the array). This is unusual in that we don't need an instance of the corresponding array-ish Zero as an input to the rule.
I had registered a specific dispatch for the primitive lax.broadcast_in_dim_p. However, I don't think this is ever called with MyArray, as the Zero is instantiated, then fed into the broadcast.
> I'm not completely sure how this might be tackled; I'd welcome any thoughts.
Short of monkey-patching jax, IDK either. I think I might just use plum on a custom zeros_like function to dispatch to MyArray (and Zero). That should work, if a little less elegantly than quax's internal use of plum.
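The custom zeros_like idea can be sketched as a function that dispatches on the type of its argument. The thread proposes plum for this; to keep the sketch dependency-free it swaps in functools.singledispatch from the standard library, which also dispatches on the type of the first argument. MyArray (with a unit field) and zeros_like are hypothetical stand-ins, not the real classes from this thread.

```python
from dataclasses import dataclass
from functools import singledispatch

import jax
import jax.numpy as jnp


@dataclass
class MyArray:
    """Hypothetical array-ish wrapper: a payload plus some metadata."""

    value: jax.Array
    unit: str


@singledispatch
def zeros_like(x):
    # Fallback: plain JAX arrays (and other array-likes) get an ordinary JAX array.
    return jnp.zeros_like(x)


@zeros_like.register
def _(x: MyArray) -> MyArray:
    # Custom type: zero the payload but keep the metadata.
    return MyArray(jnp.zeros_like(x.value), x.unit)


plain = zeros_like(jnp.ones(3))             # a jax.Array of zeros
custom = zeros_like(MyArray(jnp.ones(2), "m"))  # a MyArray with unit "m"
```

Dispatching in a user-level function sidesteps the primitive-level rule entirely, at the cost of callers having to use this zeros_like instead of jnp.zeros_like.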
> as the Zero is instantiated, then fed into the broadcast.
Actually, it's a JAX array that's instantiated! But indeed, it's not a custom MyArray either way.
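Why a rule registered for MyArray never fires can be shown with a toy, pure-Python model of type-based dispatch on primitives. Everything here (RULES, register, bind, Zero, MyArray, zeros_like) is hypothetical and is not quax's API. The assumed behavior of the broadcast rule, returning a Zero whenever the fill value is zero, is a simplification of the rule discussed above.

```python
from dataclasses import dataclass

import numpy as np

RULES = {}  # (primitive name, operand type) -> rule function


def register(primitive, operand_type):
    def decorator(fn):
        RULES[(primitive, operand_type)] = fn
        return fn
    return decorator


def bind(primitive, operand, **params):
    # Look up a rule by the exact type of the operand; otherwise do the plain op.
    rule = RULES.get((primitive, type(operand)))
    if rule is not None:
        return rule(operand, **params)
    return np.broadcast_to(operand, params["shape"])


@dataclass
class Zero:
    shape: tuple


@dataclass
class MyArray:
    value: np.ndarray


@register("broadcast_in_dim", np.ndarray)
def _broadcast_plain(operand, *, shape):
    # Assumed rule: fires on any plain fill value; no Zero input is needed.
    if np.all(operand == 0):
        return Zero(shape)
    return np.broadcast_to(operand, shape)


@register("broadcast_in_dim", MyArray)
def _broadcast_myarray(operand, *, shape):
    # Only reachable if a MyArray is itself the operand of the broadcast.
    return MyArray(np.broadcast_to(operand.value, shape))


def zeros_like(x):
    # The fill value is a fresh plain array, so dispatch sees np.ndarray
    # regardless of the type of `x`.
    data = x.value if isinstance(x, MyArray) else np.asarray(x)
    fill = np.zeros((), dtype=data.dtype)
    return bind("broadcast_in_dim", fill, shape=data.shape)
```

In this model zeros_like returns Zero((3,)) for both np.ones(3) and MyArray(np.ones(3)), which mirrors both complaints in the original report: plain inputs get a Zero, and the MyArray rule is bypassed.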
FWIW, I'm contemplating simply removing that broadcast dispatch rule from Quax. It's always been a bit magic that we produce a Zero without having a Zero input. It's also not totally reliable: something like quax.quaxify(jnp.zeros_like)(jnp.array(0)) doesn't trigger it, as that doesn't involve a broadcast. After that change, zeros/zeros_like would unconditionally produce a normal JAX array.
I'm not sure how much that helps you of course, but it's something.
Still very much a work in progress, but check out https://github.com/GalacticDynamics/array-api-jax-compat/, where I'm leveraging quax to make a bridge to the Array API that also works with JAX array-ish objects.
Oh, this looks neat! Can you tell me a bit more about this?
I'm noticing a similarity to jax.experimental.array_api; I assume you're building off of that as well?
Thanks, I was actually not aware of jax.experimental.array_api, but that will make my life significantly easier!
Each function in https://github.com/GalacticDynamics/array-api-jax-compat/ will then be a quax-ified / plum.dispatcher wrapper around jax.experimental.array_api, plus miscellaneous quax-ified functions like jacfwd and grad.
The goal is to support Astropy-like Quantities in JAX https://github.com/GalacticDynamics/jax-quantity. In that repo I've gotten most of the Astropy -> quax -> jax bridges completed.
array-api-jax-compat is now mostly a quax wrapper around jax.experimental.array_api.
Awesome!
I think this is now resolved!
Hi! I'm running into a few issues with quax + jnp.zeros_like. The built-in quax.zeros.Zeros appears to leak into jnp.zeros_like and its related functions. I think quaxify(jnp.zeros_like)(jax.Array) should output a jax.Array.
More importantly (to me), when I make a custom quax.ArrayValue subclass, it doesn't override the behavior of Zero with these functions.