Open willwhitney opened 4 years ago
Based on the notebook that you've linked, it seems that you might be aiming for the following:
In [1]: from jax import nn, numpy as jnp
In [2]: nn.one_hot(jnp.array([2, 3, 4]), 5)
Out[2]:
DeviceArray([[0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 1.]], dtype=float32)
Is that correct? That is, nn.one_hot already supports a batch of labels for a fixed class count, so there is no need to vmap.
Hypothetically, if one_hot didn't already map over its first operand, we could extend it with vmap, so long as we use in_axes to map only over the first argument:
In [1]: from jax import nn, vmap, numpy as jnp
In [2]: vmap(nn.one_hot, in_axes=(0, None))(jnp.array([2, 3, 4]), 5)
Out[2]:
DeviceArray([[0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 1.]], dtype=float32)
The None corresponding to the second argument means "do not map."
When we vmap(f), every mapped argument is treated as abstract. That leads to concretization errors when f passes a mapped array to a function that expects a concrete/static operand. We should try to improve our error messages here, more generally.
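As a minimal sketch of that failure mode (make_ones is a hypothetical helper, not from this thread): a mapped argument reaches a spot that needs a concrete Python value, and the abstract tracer can't be converted.

import jax
import jax.numpy as jnp

def make_ones(n):
    # int() needs a concrete value; under vmap, n is an abstract tracer
    return jnp.ones(int(n))

print(make_ones(3))  # fine eagerly: [1. 1. 1.]
try:
    jax.vmap(make_ones)(jnp.array([2, 3, 4]))
except Exception as err:
    print(type(err).__name__)  # a concretization / tracer-conversion error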
Sorry, my message was not sufficiently clear. one_hot already supports batched inputs; however, it can't be used inside a function which is vmapped (due to the error shown in this simple example).
For a simple contrived example:
import jax
from jax import nn, random, numpy as jnp

def f(seed):
    n = random.randint(random.PRNGKey(seed), shape=(1,), minval=0, maxval=10)
    m = 2 * n
    return nn.one_hot(n, m)

f_batch = jax.vmap(f)
In [4]: f(jnp.array(0))
Out[4]: DeviceArray([[0., 0., 1., 0.]], dtype=float32)
In [5]: f_batch(jnp.array([0, 1, 2, 3]))
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-5-b01772423584> in <module>
----> 1 f_batch(jnp.array([0, 1, 2, 3]))
[...]
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[1])>with<BatchTrace(level=0/0)>
with val = DeviceArray([[ 4],
[16],
[16],
[ 8]], dtype=int32)
batch_dim = 0.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
Note that while this example is silly, I ran into this while building something real.
A couple of points at the meta level:
1. It seems like every function in JAX should support the core JAX transformations.
2. I've had a need to make vmap work with static arguments on multiple occasions, so maybe it should have a static_argnums flag like jit does. This is probably only possible for args which have in_axes=None, but that seems fine.
3. My expectation was that if I jit a function with static_argnums, it will work with vmap. However, since vmap "opens the box" of that abstraction, it fails. In my mind jit with static_argnums should work exactly like a jax.partial for any particular value of the argument. Making transformations that are applied later respect static_argnums would improve the mental model quite a bit.

Hey Will, great to hear from you!
in_axes=None works exactly like jit's static_argnums already. You can also use lexical closure to get the same effect.
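For instance, here's a small sketch of the lexical-closure version (the fixed class count of 10 is just an assumption for illustration): the class count is captured by the closure, so vmap only ever sees the labels.

import jax
import jax.numpy as jnp
from jax import nn

num_classes = 10  # hypothetical fixed class count, closed over below

def encode(label):
    # num_classes is captured lexically, so vmap never treats it as mapped
    return nn.one_hot(label, num_classes)

batched = jax.vmap(encode)(jnp.array([2, 3, 4]))
print(batched.shape)  # (3, 10)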
Actually, one_hot can be used inside a function that is vmapped. The problem with the first example is just non-rectangularity. In terms of what values are computed, we can define vmap's behavior as:
vmap(f)(xs) == jnp.stack([f(x) for x in xs])
or, for two-argument functions, as
vmap(f)(xs, ys) == jnp.stack([f(x, y) for x, y in zip(xs, ys)])
Let's try that with the first example involving one_hot:
In [1]: import jax.numpy as jnp
In [2]: from jax.nn import one_hot
In [3]: idxs = jnp.array([0, 1, 2])
In [4]: ns = jnp.array([3, 4, 5])
In [5]: jnp.stack([one_hot(idx, n) for idx, n in zip(idxs, ns)])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-5-934b96920859> in <module>
----> 1 jnp.stack([one_hot(idx, n) for idx, n in zip(idxs, ns)])
~/packages/jax/jax/numpy/lax_numpy.py in stack(arrays, axis)
2004 for a in arrays:
2005 if shape(a) != shape0:
-> 2006 raise ValueError("All input arrays must have the same shape.")
2007 new_arrays.append(expand_dims(a, axis))
2008 return concatenate(new_arrays, axis=axis)
ValueError: All input arrays must have the same shape.
IMO this shows this isn't a vmap issue, or to do with nested functions or anything like that. Instead it's just that the result here would have to be non-rectangular, and so couldn't even be an ndarray-like result.
In the second example, JAX isn't raising the right error message, but the issue is the same: vmap can only produce rectangular results, and so trying to vmap over arguments that determine output shapes won't work.
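As a sanity check of the stack-based definition above, here is a sketch of the rectangular case: when every per-example call uses the same class count, the stacked loop and the vmapped call agree.

import jax
import jax.numpy as jnp
from jax import nn

idxs = jnp.array([0, 1, 2])
ns = jnp.array([5, 5, 5])  # all equal, so the per-example results stack cleanly

via_stack = jnp.stack([nn.one_hot(i, n) for i, n in zip(idxs, ns)])
via_vmap = jax.vmap(nn.one_hot, in_axes=(0, None))(idxs, 5)
print(jnp.array_equal(via_stack, via_vmap))  # True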
To answer your questions directly, for clarity at the risk of repetition:
"It seems like every function in JAX should support the core JAX transformations."
I think things are already working as intended here, modulo the error messages that could use improvement. Different transformations impose different constraints: for example, vmap has rectangularity constraints that limit whether you can vmap over arguments which affect the shape of the output, yet those constraints aren't relevant for grad.
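To illustrate the contrast, here is a contrived sketch (scaled_sum is a made-up function, not from this thread) whose intermediate shape depends on the value of a non-differentiated argument: grad is perfectly happy, while vmapping over that argument would hit the concreteness/rectangularity constraint.

import jax
import jax.numpy as jnp

def scaled_sum(x, n):
    # the intermediate's shape depends on the *value* of n
    return jnp.sum(x * jnp.ones(int(n)))

print(jax.grad(scaled_sum)(2.0, 3))  # 3.0 -- n stays concrete under grad
# jax.vmap(scaled_sum)(jnp.array([2., 3.]), jnp.array([3, 4]))  # would fail:
# n becomes an abstract mapped tracer, and int(n) needs a concrete value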
"I've had a need to make vmap work with static arguments on multiple occasions, so maybe it should have a static_argnums flag like jit does. This is probably only possible for args which have in_axes=None, but that seems fine."
I think in_axes=None already does what you want; arguments with in_axes=None are treated as statically as possible. Here's an example:
In [1]: from jax import vmap
In [2]: def app(f, x):
   ...:     return f(x)
   ...:
In [3]: import jax.numpy as jnp
In [4]: vmap(app, in_axes=(None, 0))(lambda x: x ** 2, jnp.arange(3))
Out[4]: DeviceArray([0, 1, 4], dtype=int32)
"Making transformations that are applied later respect static_argnums would improve the mental model quite a bit."
I think we can make the error messages here clearer, but I suggest thinking about static_argnums (and in_axes=None) differently. Each transformation can impose requirements on the function to be transformed. Those requirements are imposed by raising the abstraction level of the inputs; since a transformation can impose (but not remove) constraints, that means we might raise (but never lower) the abstraction level of arguments. By using static_argnums, you're telling jit not to raise the abstraction level of particular inputs. But that doesn't help if an outer vmap has already raised them. That is, the outer vmap is imposing constraints regardless of whether you use static_argnums with jit, or use jit at all.
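A short sketch of that last point: marking num_classes as static for jit doesn't help once an outer vmap has already raised its abstraction level.

import jax
import jax.numpy as jnp
from jax import nn

one_hot_static = jax.jit(nn.one_hot, static_argnums=(1,))

print(one_hot_static(jnp.array([2, 3, 4]), 5).shape)  # (3, 5): fine on its own

try:
    # the outer vmap has already turned num_classes into a mapped tracer
    jax.vmap(one_hot_static)(jnp.array([2, 3, 4]), jnp.array([5, 5, 5]))
except Exception as err:
    print(type(err).__name__)  # the "static" argument is no longer concrete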
A better error we could raise here would be to check whether arguments passed as static to a jitted function have already been abstracted to the shaped level, and if so either warn or error, because that's probably not intended. Ideally we could even explain why those inputs got abstracted (e.g. mention that there was a particular vmap on a particular line that caused the issue).
WDYT?
Thanks for the thorough and thought-provoking answer! I understand what's going on under the hood a bit better now, and I agree that everything is working correctly under its own precepts. Better error messages will go a long way towards pointing new users like me in the right direction.
(Though it seems like it might be nicer to get an error at runtime only when a vmap output would actually not be rectangular. This would allow code like jax.vmap(nn.one_hot)(jnp.array([2, 3, 4]), jnp.array([5, 5, 5])), and more importantly vmapping arbitrary other functions whose output sizes depend on their inputs, to work when the sizes allow. Of course I don't know if this is possible.)
At a high level, I think what I've been struggling with is finding the right mental model for reasoning about arbitrary JAX code. My default, at least when writing (nominally) Python code, is to believe that pure functions are perfectly modeled by black boxes such that the implementation of a function is irrelevant given its signature. As your discussion here nicely illustrates, this is not the right metaphor for JAX transformations.
It feels like I'm missing a fundamental concept; the constraints imposed by JAX transformations don't fit neatly into Python functional programming. I suppose really the transformations like vmap are closer to Lisp macros than to e.g. Python or Haskell, and the complete power of a macro to rewrite a function makes them hard to reason about.
I wonder if a different metaphor for transformations, one which makes what's going on under the hood clearer, would make them a bit more intuitive. Maybe, as you describe, the requirements that a transformation imposes would be a better way of thinking about things, and we could make those requirements first-class members of the API?
I don't mean to sound overly critical; if I am sometimes frustrated with JAX, it's only because it's so wonderfully powerful that I want to use it for everything. JAX is absolutely amazing, unlocking the ability to do exotic things that I would never consider doing with another library. Over a month or two of full-time work in JAX I'm building more of an intuition for how to approach things and what's possible, and when things go well the experience of "I can do that?!" is incredible.
In fact, that experience is so good that I want everyone to have it. But doing those more exotic, off-the-beaten-path things can be fraught with peril. JAX provides some very powerful lego blocks that look like they snap together in any order, but in many configurations they explode. Part of making JAX easier to learn is about tutorials and error messages, but I think part of the story may be using what we've learned so far to design metaphors that make it more obvious how to connect things together.
I think a good start for a mental model of (jittable) JAX functions (or, ~equivalently, the set of functions XLA can compile) is:

- they are functionally pure, and
- the shapes of all outputs (and intermediates) are determined by the shapes of the inputs, not by their values.
The second bullet is actually a special case of a more general property, namely that jittable JAX functions can't dynamically allocate memory after they start executing; having shapes that depend on values would force such a disallowed dynamic allocation.
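Here's a small sketch of the second bullet (keep_positive is a toy function assumed for illustration): a result whose size depends on input values works eagerly but fails under jit, because tracing only sees shapes, never values.

import jax
import jax.numpy as jnp

def keep_positive(x):
    # the result's size depends on the *values* in x
    return x[x > 0]

print(keep_positive(jnp.array([1., -2., 3.])))  # fine eagerly: [1. 3.]
try:
    jax.jit(keep_positive)(jnp.array([1., -2., 3.]))
except Exception as err:
    print(type(err).__name__)  # the output shape can't be determined at trace time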
It doesn't currently seem to be possible to use vmap on nn.one_hot. Out of the box it fails due to concretization errors. I would expect to be able to fix it by jitting with static_argnums=(1,) to make the dimension of the returned vector constant, but this just leads to a different concretization error. Notebook: https://colab.research.google.com/drive/1cVoS35UKXIj2shqDNdtkeYjlDsEMXfwX

A workaround is to do something like one_hot_10 = jax.partial(nn.one_hot, num_classes=10), but this is a bit clunky.
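For completeness, here's a runnable sketch of that workaround using functools.partial (the same idea as jax.partial), assuming a fixed class count of 10:

import functools
import jax
import jax.numpy as jnp
from jax import nn

# Fix the class count ahead of time, then vmap over the labels only.
one_hot_10 = functools.partial(nn.one_hot, num_classes=10)
print(jax.vmap(one_hot_10)(jnp.array([2, 3, 4])).shape)  # (3, 10)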