Open patrick-kidger opened 2 years ago
I think the moral of the story here is we need every API entrypoint in JAX to call jnp.asarray()
on array-like inputs.
import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.asarray(MyArray())
# TypeError: Value '<__main__.MyArray object at 0x7f67b007c4c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.
unfortunately. (With or without the disable_jit.)
Nor in the PyTree case, in which it instead fails silently:
import jax
import jax.numpy as jnp
import typing

class MyArray(typing.NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

out = jnp.asarray(MyArray())
print(repr(out))
# DeviceArray([], dtype=float32)
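The silent failure has a simple mechanical explanation: a field-less NamedTuple is a pytree with zero leaves, so flattening never consults `__jax_array__` at all. A minimal pure-Python demonstration (no JAX required; a list stands in for the array):

```python
import typing

class MyArray(typing.NamedTuple):
    """A NamedTuple with no fields -- a pytree with zero leaves."""
    def __jax_array__(self):
        return [[1.0]]  # stand-in for jnp.array([[1.]])

m = MyArray()
# Pytree flattening of a NamedTuple walks its fields; with no fields
# there are zero leaves, so the conversion builds an empty array and
# the __jax_array__ method is never called.
leaves = list(m)
assert leaves == []
```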
Well, in that case, we need __jax_array__ to be respected by jnp.array().
The real issue is that __jax_array__ is only partially supported throughout the JAX package, with very little test coverage, and making it fully supported and tested will take a lot of work. (For what it's worth, this is the kind of issue I had in mind when I initially advocated against adding it.)
I'll assign the issue to @mattjj, who didn't anticipate any support burden 😀
> I'll assign the issue to @mattjj, who didn't anticipate any support burden 😀

Haha!
FWIW, as a developer I would also not have added __jax_array__, for exactly this reason. But as an end user I'm actually really enjoying having it around! For example, over in https://github.com/patrick-kidger/equinox/issues/53 we're discussing using it to implement spectral norm as a "parameterised weight" -- much like how PyTorch does via torch.nn.utils.parametrize.
@jakevdp that was a good one. Though the reason the support burden was limited is we didn't tell anyone about it or promise anything about it!
(Also I haven't paged in this whole thread yet, but I'd be happy to delete __jax_array__.)
IIUC this issue only arises with an object that is both a pytree and has a __jax_array__ defined. I think we should say that in that case the pytree semantics take precedence, right? (In general that situation seems ambiguous as to whether something should be considered a container or an array-like leaf. So we could also just say it's undefined behavior.)
> (Also I haven't paged in this whole thread yet, but I'd be happy to delete __jax_array__.)
:( Please don't!
> IIUC this issue only arises with an object that is both a pytree and has a __jax_array__ defined. I think we should say that in that case the pytree semantics take precedence, right?
Nope -- it's just that __jax_array__ isn't being respected in some places.
I'm actually finding that it's very natural to have something that is both a PyTree and a JAX array:
class Buffer(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return lax.stop_gradient(self.value)

class SpectralNorm(typing.NamedTuple):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __jax_array__(self):
        u, v = power_iteration(self.u, self.v, self.weight)
        σ = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / σ

class Symmetric(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return self.value + self.value.T
etc. etc.
In each case it's a PyTree wrt jit etc., as we need to transparently see the wrapped value. And it's a __jax_array__ with respect to jnp etc. operations, which see the transformed value.
If jnp.whatever happened to be generic over PyTrees then there would indeed be a problem. But they're not! They only accept JAX arrays. Meanwhile jax.jit only accepts PyTrees etc. and doesn't/shouldn't care about __jax_array__, I think. (So IMO it's wrong that jax.jit currently handles __jax_array__ in some places.)
AFAICT, PyTrees and __jax_array__ are pretty much orthogonal. Off the top of my head, there isn't an API which needs to make a choice between the two interpretations.
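The claimed orthogonality can be illustrated in plain Python: the same object supports two independent interpretations, flattening to leaves (the container view that jit-like machinery uses) and conversion to an array (the view jnp-like functions would use). No JAX here; nested lists stand in for arrays, and the symmetrisation mirrors the Symmetric example above:

```python
import typing

class Symmetric(typing.NamedTuple):
    value: list  # stand-in for a jnp.ndarray

    def __jax_array__(self):
        # Array view: the symmetrised matrix (stand-in for value + value.T).
        n = len(self.value)
        return [[self.value[i][j] + self.value[j][i] for j in range(n)]
                for i in range(n)]

s = Symmetric([[0.0, 1.0], [2.0, 3.0]])

# Container view: flatten to leaves -- the wrapped value, untransformed.
assert list(s) == [[[0.0, 1.0], [2.0, 3.0]]]

# Array view: call __jax_array__ -- the transformed value.
assert s.__jax_array__() == [[0.0, 3.0], [3.0, 6.0]]
```

The two views never have to be reconciled as long as each API commits to exactly one of them.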
Very interesting points @patrick-kidger -- I started writing a comment pushing back against this idea, but while writing it I ended up realizing you're right 😁. This idea of the orthogonality of tree flattening and __jax_array__ is pretty compelling. That said, if there is a class that is not a pytree and defines __jax_array__, ISTM that __jax_array__ should be called at the JIT boundary. What do you think?
:D
> That said, if there is a class that is not a pytree and defines __jax_array__, ISTM that __jax_array__ should be called at the JIT boundary. What do you think?
I don't think so. (a) For consistency between pytree/non-pytree; (b) for consistency between jit/non-jit. For example the following would work without JIT but would fail with JIT.
class M:
    def __jax_array__(self):
        return jnp.array(1.)

    def foo(self):
        return jnp.array(2)

# @jax.jit
def call(m):
    return m + m.foo()

call(M())
Which is admittedly a bit contrived -- I don't think there are that many classes being passed into JIT that aren't also pytrees -- but I don't see a compelling reason to perform the __jax_array__ conversion across JIT boundaries either.
I see what you're saying there, but I think it's probably counter to the original intent of the __jax_array__ mechanism, which (as I understand it) imagined it as a way of making an arbitrary object be interpreted by JAX as an array.
Right! Just, interpreted by jnp and friends, rather than interpreted by jax.jit. (What about jax.vmap, jax.grad etc.?)
Perhaps our mental models as to what jax.jit should do are slightly different. (I see the argument the other way.) Anyway, wrt this latter point it's not a strong feeling on my part.
I disagree about orthogonality. For example, in jax.lax.scan(f, None, Symmetric(jnp.ones((1, 2)))), what's the length of the scanned-over axis? (In words, for APIs which accept pytrees-of-arrays like scan, an object which is both registered as a non-leaf pytree and has a __jax_array__ method can either be interpreted as a non-leaf pytree or as an array, and a decision must be made.)
I think we've got two discussions here, one narrowly about jax.linalg.svd and co (i.e. the original issue), and the other about how __jax_array__ should behave more generally, e.g. whether jit boundaries should call __jax_array__.
For jax.linalg.svd and co, there may be a quick way to extend __jax_array__ support to those functions even with disable_jit (and/or with pytree registration). But doing that is not a high-priority enhancement. (I call it an enhancement and not a bug because #4725 made no promises about the functionality of __jax_array__ beyond whatever was handled in that PR.)
For the latter discussion about __jax_array__'s behavior e.g. at jit boundaries, while the behavior Patrick wants could be reasonable, it is indeed counter to the original intent of the __jax_array__ mechanism, which indeed was quite narrow.
I think we should:
1. not grow the __jax_array__ API, but also
2. work out a more complete story for customizing jnp functions' behavior as well as transformation behaviors (we have some work on this already), and
3. keep __jax_array__ until item 2 lands.

WDYT?
Hmm, good point about lax.scan. In a few cases a choice does have to be made.
Anyway, everything you say sounds reasonable. I enjoy using __jax_array__; any equivalent/more-general functionality also sounds great.
(In the end I think these kinds of proposals all end up being special cases of multiple dispatch.)
Another problematic point might be jax.scipy.sparse.linalg.cg & co: imagine a PyTree object that behaves as a matrix, but is not one, and that we can convert to a JAX array. What should the API do in that case? It should do the iterative solve using the lazy version, not the dense operator, which is there just for users' ease of use.
class LazyMatMul(typing.NamedTuple):  # NamedTuples are registered pytrees
    a: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, v):
        return self.b @ (self.a @ v)

    def __matmul__(self, v):
        return self.b @ (self.a @ v)

    def __jax_array__(self):
        return self.b @ self.a

ab = LazyMatMul(jnp.ones((3, 3)), jnp.ones((3, 3)))
jax.scipy.sparse.linalg.cg(ab, jnp.ones(3))
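The point of preferring the lazy version: b @ (a @ v) costs two matrix-vector products, while the dense __jax_array__ view forms b @ a first with a full matrix-matrix product, yet both give the same result. A quick numpy check of that equivalence (numpy standing in for jnp):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 3))
b = rng.standard_normal((3, 3))
v = rng.standard_normal(3)

lazy = b @ (a @ v)    # two O(n^2) matvecs -- what cg should use
dense = (b @ a) @ v   # O(n^3) matmul first -- the __jax_array__ view

# Same answer, very different cost for large n.
assert np.allclose(lazy, dense)
```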
@mattjj In favour of not removing __jax_array__: it allows writing code that is backend-agnostic using NEP 47.
Maybe the API could follow some priorities, e.g. method-specific type (e.g. callable for cg) > pytree > __jax_array__.
BUT.
What about a pytree of (pytree and __jax_array__)?
Hello @patrick-kidger and @mattjj, sorry to unearth this but I had the same problem with part of the JAX API. I want to make a PyTree which is a valid JAX array to attach metadata to arrays.
I noticed that jax._src.lax.lax.asarray does not comply with __jax_array__, which is easy to fix.
In addition, jax.core.Primitive.bind checks if arguments are jax.core.valid_jaxtype, but does not call __jax_array__() on arguments that need it. This is problematic if the result of __jax_array__() is a Tracer. Once again, it is easy to fix as well.
With these two patches, it seems that my PyArray wrapper is compatible with a good part of jax.numpy and jax.lax. It worked with everything I tried, actually.
__jax_array__ is undocumented and (mostly) untested. We have no intent to support __jax_array__ universally in the JAX package, and I would suggest not writing code that relies on it.
> I want to make a PyTree which is a valid JAX array to attach metadata to arrays.
Can you say more about this use-case? There may be better approaches to doing what you have in mind.
Hello @jakevdp
> Can you say more about this use-case? There may be better approaches to doing what you have in mind.
I am writing a small JAX library (Inox) in which modules are PyTrees (similar to Equinox) whose leaves are the internal arrays. However, I need a way to distinguish between arrays that are constants, parameters, running statistics, ... for updates. My approach is to wrap arrays into a shallow PyTree with static metadata (what I call a PyArray). This is similar to the way PyTorch and Flax indicate parameters.
With the __jax_array__ interface, I can make PyArray instances valid JAX arrays, meaning that users don't have to unpack PyArray instances to use them as arrays (same as torch.nn.Parameter). Note that I don't care about propagating the metadata.
I also tried to make PyArray a Tracer with the EvalTrace to not propagate the trace, but did not succeed.
> __jax_array__ is undocumented
Indeed, it was fairly painful to find a solution. But I think the API has potential!
We don't have any support for this kind of implicit dispatch to user-defined types. It's something we've discussed, but we haven't yet found use-cases that warrant the kind of investment it would require. I'm fairly certain though that if we did this, it would not rely on __jax_array__ for dispatch.
For your use-case, I would suggest explicitly unwrapping your wrapped arrays before passing them to jax APIs.
An alternative would be to enable metadata in arrays directly (names, roles, ...). In some sense that is the concept of Tracer.
> An alternative would be to enable metadata in arrays directly (names, roles, ...). In some sense that is the concept of Tracer.
This can absolutely be done, actually -- take a look at Quax. This allows for tracing "array-ish" objects, and then doing multiple dispatch on them at the level of a primitive bind. This looks something like:
class LoraArray(quax.ArrayValue):
    ...

@quax.register(lax.dot_general_p)
def _(x: LoraArray, y: Array, ...):
    ...  # implement LoRA matmuls for this new array-ish type.

quax.quaxify(some_function)(LoraArray(...), jnp.array(...), ...)
with the quaxify transforming the function (just as jax.jit etc. do), and during the tracing multiple dispatch rules are looked up. FWIW, right now this is pretty experimental, but it may be useful to you as a starting point.
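The dispatch-at-bind idea itself can be sketched in plain Python: a registry keyed by (primitive, argument types), consulted when a primitive is bound. This is a toy illustration of the pattern only, not Quax's or JAX's actual machinery; all names here are made up:

```python
# Toy multiple-dispatch registry, keyed on a primitive name and the
# classes of the arguments.
_rules = {}

def register(primitive, *types):
    def wrap(fn):
        _rules[(primitive, types)] = fn
        return fn
    return wrap

def bind(primitive, *args):
    # Look up a rule matching the runtime types of the arguments.
    for (prim, types), fn in _rules.items():
        if (prim == primitive and len(types) == len(args)
                and all(isinstance(a, t) for a, t in zip(args, types))):
            return fn(*args)
    raise NotImplementedError(f"no rule for {primitive}")

class LoraArray:
    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b

@register("dot_general", LoraArray, list)
def _lora_dot(x, y):
    # Stand-in "matmul" rule for the custom array-ish type.
    return ["lora-matmul", y]

assert bind("dot_general", LoraArray(1, 2, 3), [4.0]) == ["lora-matmul", [4.0]]
```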
That looks like what I need! Although, I have trouble understanding some parts of the Tracer and Trace interface, which is not very well documented. For example, shouldn't _QuaxTrace have a main attribute?
Also, instead of using a decorator, would it be possible to add a QuaxTrace in the trace_stack without ever removing it? Like the EvalTrace at the bottom.
> That looks like what I need! Although, I have trouble understanding some parts of the Tracer and Trace interface, which is not very well documented. For example, shouldn't _QuaxTrace have a main attribute?
It does, via jax.core.Trace itself:
> Also, instead of using a decorator, would it be possible to add a QuaxTrace in the trace_stack without ever removing it? Like the EvalTrace at the bottom.
Hmm, that's an interesting idea! And one that I really like, actually.
I'm not completely sure -- it might end up having to touch JAX internals? (E.g. it might end up being morally equivalent to monkey-patching EvalTrace, which wouldn't be great.)
Okay, let's do some inside baseball -- let me just say a bit more about your idea of putting a QuaxTrace at the bottom of the trace_stack.
A big part of why I like that idea so much is that right now it's a fairly complicated business to write a dispatch rule. Using LoRA as an example, we actually need our Quax rule to call back into a special version of quaxify, in order to redispatch on the type of rhs, or to handle the possibility that lhs.{w,a,b} are themselves array-ish values. This ends up making it a fairly tricky business to write rules correctly.
This also ties in with a personal (private) project I've got, reimplementing + varying some of JAX's ideas. This does something very similar to your suggestion -- it has multiple dispatch as a first-class citizen of the bottom-of-stack evaluation trace, and in doing so actually manages to handle stuff like abstract evaluation (and maybe also JIT'ing?) as a special case of this single notion of evaluation. And this is pretty neat!
Besides the above, I can see that you're also interested in this because it removes the need for the quaxify wrapper itself, which is a plus for usability.
If you're interested in playing with this idea then I'd love to know what you find. And I'm definitely open to changing Quax in this way if there's a better design choice. Maybe we can do something interesting with this!
I will probably try to add multiple dispatch to the autodidax tutorial instead of the actual JAX API. I was also curious about caching dispatch results (e.g. in the case of reparameterizations).
@jakevdp Would adding a permanent DispatchTrace in trace_stack or modifying EvalTrace be a viable solution to enable multiple dispatch in JAX?
We would probably not make any change of that scope without a more comprehensive design process: i.e. writing out the goals and non-goals, evaluating the space of possible approaches and their advantages and disadvantages, giving stakeholders time to provide feedback, and only then starting to write the code.
Hi @patrick-kidger
This issue appears to have been resolved in JAX version 0.3.22 itself, with PR #12693, which expanded the support for __jax_array__ in jnp.array. I tested the mentioned code with various versions of JAX after 0.3.22 and it now works as expected. The following code does not produce any error:
import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

with jax.disable_jit():
    jnp.linalg.svd(MyArray())  # works fine without any error
and
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
produces:
Array([[1.]], dtype=float32)
Please find the gist for reference.
Thank you.
This particular issue is resolved, but the larger issue still remains: we don't comprehensively support __jax_array__ across the package, nor have we ever intended to. This is a non-public API that is undocumented and for the most part untested, and we make no guarantees about its stability or comprehensiveness. Downstream packages should not depend on it.
A couple things going on here. First of all, the following is an example of jnp.linalg.svd failing to respect __jax_array__. Remove the disable_jit and this works.

The reason it works without disable_jit is that jnp.linalg.svd and friends all have jax.jit wrappers, which is what spots the __jax_array__ and handles things appropriately... unless the JAX arraylike is also a PyTree, in which case they don't. So this also fails (with a different error message this time).

So whilst it takes either a disable_jit or a PyTree to actually trigger it, I think the fundamental issue here is that jnp.linalg.svd and friends do not check for JAX arraylikes.