Closed NeilGirdhar closed 1 year ago
Current workaround:
# Get the impl and raw data from an existing key array:
print(key_array.impl)  # Get type.
print(key_array.unsafe_raw_array())
# and then...
import jax.numpy as jnp  # import added for completeness
from jax._src.prng import PRNGKeyArray, threefry_prng_impl
raw_array = jnp.array((2634740717, 3214329440), dtype=jnp.uint32)
key_array = PRNGKeyArray(threefry_prng_impl, raw_array)  # Create using a type and value.
If you're importing _src.prng for workarounds, a better workaround may be:
jax._src.prng.random_wrap(raw_array, impl=impl)
In particular, this works under jit.
I'll try to write back soon about the broader questions up top. Just wanted to get this to you in the meantime. As always, thanks for putting thought into these things and for the discussion/feedback, Neil!
Coming back to this, the "workaround" might actually be an answer to part of the question. @NeilGirdhar – regardless of the idiom by which we set up PRNG implementations, my original thought was to expose random_wrap and random_unwrap (perhaps by a different name) as public API functions. Would that suffice for your needs as more than a workaround?
I'm also happy to separately review improvements to the internal idiom we use for PRNG implementations (e.g. a switch to inheritance), assuming it's a separate concern, and especially if it cleans things up or makes them easier!
The reason that I propose to have random_{,un}wrap be the only way in/out of key arrays follows my previous comment: these functions bind corresponding array-casting primitives, and in particular calling them is invariant under jit (and staging more generally). Constructing or unwrapping a Python PRNGKeyArray (or derivative thereof) is merely the corresponding impl rule for these primitives.
Does this seem right?
Sorry for the delay. I wanted to invest some time in understanding random_wrap. I'm afraid I don't understand what JIT invariance means. My understanding of JAX internals still has a long way to go. I don't want to use up your valuable time, but why is it that simple construction doesn't work with the JIT? In other words, why do you need to do random_wrap(array, impl) instead of PRNGKeyArray(impl, array)?
Would that suffice for your needs as more than a workaround?
That means you'll be exposing the PRNGImpl classes like threefry_prng_impl? Yes, that definitely works.
I'm also happy to separately review improvements to the internal idiom we use for PRNG implementations (e.g. a switch to inheritance), assuming it's a separate concern, and especially if it cleans things up or makes them easier!
Yes, it's a separate concern! And I love your commitment to finding the best design.
This seems to be related:
from jax import enable_custom_prng, vjp
from jax.random import PRNGKey
with enable_custom_prng():
    def f(i):
        return PRNGKey(i)
    out, f_vjp = vjp(f, 1)  # Fails!
Would it be possible to make new-style KeyArrays pytrees? And then arm PRNGKey with a custom derivative that sends a zero cotangent back?
Also, how do I produce a zero cotangent for a new-style key array?
In other words, why do you need to do random_wrap(array, impl) instead of PRNGKeyArray(impl, array)?
Calling random_wrap(array, impl) binds the random_wrap_p primitive. When we're staging, we can capture it as such, and so we can later lower it, as well as lower operations on its output (such as slicing; see the gather below):
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
... k = jax._src.prng.random_wrap(x, impl=jax._src.prng.threefry_prng_impl)
... k = k[:2, :3]
... return jax.vmap(jax.vmap(jax.random.bernoulli))(k)
...
>>> with jax.enable_custom_prng():
... print(jax.make_jaxpr(f)(jnp.ones((5, 6, 2), jnp.uint32)))
...
{ lambda ; a:u32[5,6,2]. let
b:key<fry>[5,6] = random_wrap[impl=fry] a
c:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
d:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
e:i32[2] = concatenate[dimension=0] c d
f:key<fry>[2,3] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(2, 3)
unique_indices=True
] b e
g:u32[2,3] = random_bits[bit_width=32 shape=()] f
h:u32[2,3] = shift_right_logical g 9
i:u32[2,3] = or h 1065353216
j:f32[2,3] = bitcast_convert_type[new_dtype=float32] i
k:f32[2,3] = sub j 1.0
l:f32[] = sub 1.0 0.0
m:f32[2,3] = mul k l
n:f32[2,3] = add m 0.0
o:f32[2,3] = max 0.0 n
p:bool[2,3] = lt o 0.5
in (p,) }
>>>
When evaluating eagerly, the primitive is indeed implemented as the expression PRNGKeyArray(impl, array).
Would it be possible to make new-style KeyArrays pytrees?
Not quite, because they aren't equivalent to (some tuple of) basic arrays. It's maybe easiest to consider a few examples of what would go wrong if you made them pytrees (naturally, ones that flatten to the underlying uint32 array). One example is that we can vmap a function over them down to the last dimension, then operate on the key bits, which is something we want to disallow. Another example is that if we tree_map over them, we'd accidentally "unwrap" them back down to uint32 arrays.
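The "accidental unwrap" hazard can be sketched with a toy tree_map in plain Python, no JAX required. Everything here (the Key class, flatten, tree_map) is an invented stand-in for illustration, not JAX's actual pytree machinery.

```python
# Toy stand-ins for pytree flattening; not JAX's implementation.

class Key:
    """A pretend typed key wrapping raw uint32-like bits."""
    def __init__(self, bits):
        self.bits = bits

def flatten(x):
    # Suppose keys were registered as pytrees flattening to raw bits.
    if isinstance(x, Key):
        return x.bits
    return [x]

def tree_map(f, x):
    # A generic tree_map only ever sees the flattened leaves...
    return [f(leaf) for leaf in flatten(x)]

key = Key([2634740717, 3214329440])
# ...so an innocuous-looking map lands directly on the key's raw bits,
# exactly the kind of access that typed key arrays try to rule out:
leaves = tree_map(lambda bits: bits & 0xFFFF, key)
```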
Also, how do I produce a zero cotangent for a new style key array?
What do you do for current uint32-style raw key arrays?
This seems to be related: [...]
That's a bug! Thanks for catching it. Filed #14856
We should tag the tracking issue too: #9263
This is a great thread Neil! Thanks for kicking the tires and asking useful questions.
It's maybe easiest to consider a few examples of what would go wrong if you made them pytrees (naturally, ones that flatten to the underlying uint32 array). One example is that we can vmap a function over them down to the last dimension, then operate on the key bits, which is something we want to disallow. Another example is that if we tree_map over them, we'd accidentally "unwrap" them back down to uint32 arrays.
Just a thought, but does JAX support structured arrays? If so, you could make all of the key information a single record. Then, it would be essentially atomic in the eyes of vmap and tree-map.
A related, but less elegant solution would be to register new dtypes corresponding to the key types of your PRNGs.
What do you do for current uint32-style raw key arrays?
I don't do anything yet. I'm investigating storing the key array that I used on various components of my model in the inference output. In order for that to work, key arrays need to be pytrees, and I need to be able to produce zero cotangents for them.
It does seem a bit odd that I can JIT a function that accepts a seed dynamically, but I can't jit a function that accepts a key array dynamically.
Just a thought, but does JAX support structured arrays? If so, you could make all of the key information a single record.
It doesn't, but we have a semi-internal notion of "opaque dtypes" on which the key arrays (and one or two other such things) are built. See #11768, renamed in #12170.
The two are slightly different. Structured arrays allow for a standard numpy array of structs, whereas opaque dtypes allow the data of an array to be more arbitrary in structure, and allow the array to customize which operations it supports. With opaque dtypes, our underlying data doesn't have to be an array-of-records, slicing doesn't have to be slicing into such an array-of-records, various array operations might not be allowed, etc.
Then, it would be essentially atomic in the eyes of vmap and tree-map.
Yeah, there might be several ways to achieve this type of "atomicity in the eyes of vmap." The opaque dtypes mechanism is sufficient but probably not necessary. Another approach we've kicked around is a generalization of pytrees that involves instantiating typeclasses for different transformations (e.g. vmappable means roughly that you have a pair of rules from_elt and to_elt). This isn't a complete exploration though, especially not for all transformations.
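As a rough sketch of that typeclass idea (invented names, plain Python; this is not how JAX implements anything): a type could register a to_elt/from_elt pair, and a toy vmap would use them to move between a batch and its elements without ever touching the raw data generically.

```python
# Hypothetical "vmappable" typeclass registry; all names are invented.

vmappable_rules = {}

def register_vmappable(ty, to_elt, from_elt):
    vmappable_rules[ty] = (to_elt, from_elt)

class Keys:
    """A batch of pretend keys stored as a list of raw pairs."""
    def __init__(self, raw):
        self.raw = raw

class Key:
    """A single pretend key."""
    def __init__(self, pair):
        self.pair = pair

# to_elt: pull element i out of the batch; from_elt: rebuild the batch.
register_vmappable(
    Keys,
    to_elt=lambda ks, i: Key(ks.raw[i]),
    from_elt=lambda elts: Keys([e.pair for e in elts]),
)

def toy_vmap(f, xs, length):
    to_elt, from_elt = vmappable_rules[type(xs)]
    return from_elt([f(to_elt(xs, i)) for i in range(length)])

batch = Keys([(1, 2), (3, 4)])
out = toy_vmap(lambda k: k, batch, 2)  # maps over Key elements
```

The point of the sketch: vmap over a Keys batch hands the function whole Key elements, so the batch is atomic without needing opaque dtypes.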
It does seem a bit odd that I can JIT a function that accepts a seed dynamically, but I can't jit a function that accepts a key array dynamically.
How do you mean? This sounds like another potential bug. I can do this:
>>> with jax.enable_custom_prng():
... print(jax.jit(jax.random.bernoulli)(jax.random.PRNGKey(3)))
...
False
Thanks for the very informative explanation!
How do you mean? This sounds like another potential bug. I can do this:
I guess I misunderstood the consequences of key arrays not being pytrees. If they're not a pytree, how are they dynamically passed to a jitted function?
And an unrelated question, after #14856 is fixed, how should I produce a zero cotangent for a key array? For a normal pytree, you would do tree_map(lambda x: jnp.zeros_like(x), some_tree).
If they're not a pytree, how are they dynamically passed to a jitted function?
They're indeed not a pytree, but they do correspond to a "JAX type" (that error you were getting with the VJP bug is spurious). This is part of what the internal "opaque dtype" mechanism sets up. In other words: we made it work!
And an unrelated question, after #14856 is fixed, how should I produce a zero cotangent for a key array?
I'm not sure that it makes sense to make one? Key arrays correspond to uint-dtyped arrays under the hood, but either way neither of those forms a vector space. If a function returns a key array, what does it mean to differentiate it? Int dtypes present a special case that we (perhaps questionably) decided to handle, and so zeros_like works there.
Pragmatically: if you're taking VJP of a function whose output has some key arrays and some numeric ones, you could wrap it in one that only outputs numeric ones (and drops the keys) and take VJP of that?
There's also that has_aux option to jax.vjp. Does that get you what you need?
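To make that suggestion concrete, here's a minimal sketch of the wrapping idea using plain Python stand-ins (is_key_like, split_for_vjp, and model are invented names): route key-valued outputs through an aux slot so that only numeric outputs would be differentiated, which is the shape jax.vjp(..., has_aux=True) expects.

```python
# Sketch only: split a mixed output into (numeric, aux) so keys stay
# out of differentiation. Not JAX code; the predicate is invented.

def is_key_like(x):
    return isinstance(x, tuple)  # pretend keys are raw uint32 pairs

def split_for_vjp(f):
    """Wrap f so it returns (numeric_outputs, aux_outputs)."""
    def wrapped(*args):
        out = f(*args)
        numeric = {k: v for k, v in out.items() if not is_key_like(v)}
        aux = {k: v for k, v in out.items() if is_key_like(v)}
        return numeric, aux
    return wrapped

def model(x):
    return {"loss": x * 2.0, "key": (2634740717, 3214329440)}

# With JAX one would then call jax.vjp(split_for_vjp(model), x,
# has_aux=True); here we just demonstrate the split itself.
numeric, aux = split_for_vjp(model)(1.5)
```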
I do appreciate your point that the interaction between key arrays and our AD API changes as we upgrade to typed key arrays, and some user-side code might need rewriting. We'll have to think on whether we can make the upgrade easier...
I do appreciate your point that the interaction between key arrays and our AD API changes as we upgrade to typed key arrays, and some user-side code might need rewriting. We'll have to think on whether we can make the upgrade easier...
Right. I may be able to work around this, but it's going to be complicated.
One pattern that's pretty common in my code is to have classes like this:
from __future__ import annotations
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class GeometrySamplerLoss:
    attention_prediction_error: RealArray
    rivals_prediction_error: RealArray
    attention_curvature: RealArray
    rivals_curvature: RealArray

    @classmethod
    def zeros(cls) -> GeometrySamplerLoss:
        z = jnp.asarray(0.0)
        return GeometrySamplerLoss(z, z, z, z)

    @classmethod
    def cotangent(cls) -> GeometrySamplerLoss:
        o = jnp.asarray(1.0)
        z = jnp.asarray(0.0)
        return GeometrySamplerLoss(o, o, z, z)
When inferring a trajectory using scan, I use the zeros method to initialize the state, and the cotangent method to produce a custom cotangent to push into a VJP to train these model components.
I understand your idea about not adding key arrays to model component outputs, but instead creating auxiliary outputs. This will complicate my design. I'll have to have a pair of outputs: one for regular things, and one for auxiliary things. (I guess while I'm at it, I can put all of the non-differentiated things in the auxiliary output, which is a benefit.)
Alternatively, we could consider something like jax.custom_derivatives.zero_cotangent:
def zero_cotangent(x: T) -> T:
    """Return a zero cotangent like x."""
Is that what interpreters.ad.instantiate_zeros does?
They're indeed not a pytree, but they do correspond to a "JAX type"
Maybe I don't understand the definition of "pytree". I thought anything that supports tree_util.tree_flatten is a pytree? Or does pytree imply an aggregate type? So Array is not a pytree for that reason?
Maybe I don't understand the definition of "pytree". [...]
I rather meant more narrowly that the type is not registered with tree_util.register_pytree_node[_class]. By your more basic definition, key arrays are indeed pytrees (specifically, leaves, just like other arrays are).
Right. I may be able to work around this, but it's going to be complicated.
Yeah, I can understand that this is more complex to express, at least relative to what you could write earlier. We'll have to think about this, but perhaps slightly orthogonally.
For example, maybe we should allow None or a SymbolicZero(shape, dtype) as a generalized convenient shorthand for (co-)tangents that indicate "this is not involved in differentiation." The shorthand here is really just to avoid the has_aux and closure rewriting, but it would have essentially equivalent meaning. And it could be nice for more than only key arrays, too.
Alternatively, we could consider something like
jax.custom_derivatives.zero_cotangent
This sounds like it'd be along the lines of the above, in that it could be how you'd obtain such a hypothetical SymbolicZero(shape, dtype)
. We might also want it for tangents. Maybe we could call it symbolic_zero_perturbation_like
or something that's less verbose but along the same lines.
I'm thinking out loud here.
Is that what interpreters.ad.instantiate_zeros does?
Kind of, but not exactly, in that it doesn't know about primal-tangent correspondences. Anyway, that's an internal helper, which will be made available only via jax._src at the next opportunity. I wouldn't depend on it.
This sounds like it'd be along the lines of the above, in that it could be how you'd obtain such a hypothetical SymbolicZero(shape, dtype). We might also want it for tangents. Maybe we could call it symbolic_zero_perturbation_like or something that's less verbose but along the same lines. I'm thinking out loud here.
Awesome.
I think there's a lot of beauty to the way that you can return None out of a custom VJP to represent a zero cotangent for a non-differentiated primal input. What I'm looking for is a way to pass something into a custom VJP to represent a zero cotangent of a non-differentiated primal output.
So, I really like your idea of SymbolicZero. Like you say, it would be nice for it to work for more than only key arrays, ideally all "JAX types"?
Anyway, that's an internal helper, which will be made available only via jax._src at the next opportunity. I wouldn't depend on it.
I'm definitely not using that! :smile: I only found it because I guessed that you might be producing zero cotangents somewhere in Jax code. If you already have a zero cotangent function inside Jax, I figured it would strengthen the argument for exposing such a function.
I brought up both None and SymbolicZero because both come up in custom JVPs as of #14570. Symbolic zeros enter the rule, but the rule can return Nones if preferred. Plain Nones are fine if information like the shape and dtype aren't needed or can be inferred. I haven't entirely thought through whether a None is enough when invoking a VJP.
If you already have a zero cotangent function inside Jax, I figured it would strengthen the argument for exposing such a function.
To be sure, I think we've determined that this isn't quite the function you're looking for anyway – is that right?
That is, your request is not "let me obtain a proper zero cotangent for this arbitrary primal array." Instead it is "let me indicate, by passing a sentinel value in place of a cotangent, an intent not to perturb at this output position."
(In fact, there isn't a cotangent space for your primal array, and in particular no zero cotangent.)
Is that what interpreters.ad.instantiate_zeros does?
Kind of, but not exactly, in that it doesn't know about primal-tangent correspondences. Anyway, that's an internal helper, which will be made available only via jax._src at the next opportunity. I wouldn't depend on it.
I do depend on this... :D Please keep it available somewhere!
Plain Nones are fine if information like the shape and dtype aren't needed or can be inferred. I haven't entirely thought through whether a None is enough when invoking a VJP.
I'm pretty sure plain Nones are totally fine. That's the API offered by equinox.filter_vjp and it works smoothly. (After all, you can always just reconstruct the symbolic zero with a Zero(core.get_aval(primal).at_least_vspace()).)
That is, your request is not "let me obtain a proper zero cotangent for this arbitrary primal array." Instead it is "let me indicate, by passing a sentinel value in place of a cotangent, an intent not to perturb at this output position."
Yes, exactly!
Ideally, the system should work on components:
primals_out, f_vjp = vjp(f, *primals_in)
cotangents_out = f_vjp(*cotangents_in)
Where cotangents_in is a tuple of pytrees of cotangent inputs, some of which may contain symbolic zero components. For example, cotangents_in[3] may be equal to (jnp.ones(3), {'foo': SymbolicZero(PRNGKey(2))}).
This goes beyond what's supported by None in a custom VJP backward pass (which can only send None for an entire cotangent output, not for a component). My ideal interface is therefore SymbolicZero(any_pytree), where any_pytree could be an array, a key array, or any aggregate pytree type.
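A minimal sketch of what such a sentinel might look like (hypothetical; SymbolicZero and resolve_cotangent here are invented for illustration and are not JAX APIs):

```python
# Hypothetical SymbolicZero sentinel; names and semantics invented.

from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class SymbolicZero:
    """Marks 'do not perturb here' in place of a concrete cotangent."""
    like: Any  # the primal (or pytree component) this stands in for

def resolve_cotangent(ct, make_zero):
    # A backward pass could materialize or skip these as it pleases.
    if isinstance(ct, SymbolicZero):
        return make_zero(ct.like)
    return ct

# A cotangent pytree mixing a concrete value with a symbolic zero:
cotangent_in = (1.0, {"foo": SymbolicZero(like=(2, 2))})
resolved = (
    resolve_cotangent(cotangent_in[0], lambda like: 0.0),
    {"foo": resolve_cotangent(cotangent_in[1]["foo"], lambda like: None)},
)
```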
I'm pretty sure plain Nones are totally fine. That's the API offered by equinox.filter_vjp and it works smoothly
So you support None even in place of sub-elements of cotangents (as in my example above)?
Closing because at this point we have jax.random.{key_data,key_impl} for unwrapping and jax.random.wrap_key_data for wrapping.
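For reference, a small sketch of the resulting public API (assuming a recent JAX release with typed keys; this snippet is not taken from the thread):

```python
import jax

key = jax.random.key(0)                   # new-style typed key array
raw = jax.random.key_data(key)            # unwrap to the raw uint32 bits
restored = jax.random.wrap_key_data(raw)  # wrap the bits back into a key
impl = jax.random.key_impl(key)           # inspect the PRNG implementation
```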
I've been working on producing a MWE for this issue for the last month or so. Part of this is setting everything up identically right before the freeze.
What I want to do is run the program that fails and print out the key array, and then produce a MWE where I reconstruct that key array. The problem is that I'm using the new custom PRNG and these don't seem to have any interface for construction from an array.
I don't want to step on any toes, but it seems to me that the design complicates this kind of use pattern. There are some excellent aspects to the current design. Jax does an excellent job of having small interfaces, and the new PRNG is no exception. However, it's just one function too small for my needs.
Trying to shoehorn this feature request to fit the pimpl idiom looks like it's going to be awkward. It would end up being something like PRNGKeyArray(impl, key_array), which means exposing the pimpl objects (threefry_prng_impl, etc.) and PRNGKeyArray, and providing access to the pimpl objects, all things that go against the whole pimpl idea in the first place. Or perhaps adding methods to construct from a base array for each implementation?
I want to humbly suggest removing the pimpl pattern, and switching to using ordinary classes and inheritance. Pimpl was popularized in C++ as a compilation firewall (for example, this excellent description by C++ standards chair Herb Sutter). The idea is that it allows you to completely hide private members (methods and data) in compiled source, and expose a minimal header with only public methods. That's why it rarely turns up in Python, since there is no distinction between source and header.
The ordinary Python way to do this is:
Then, you can reproduce training examples by dumping:
and then restoring:
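The dump and restore snippets were lost in extraction; a hypothetical sketch of the intended round trip might look like the following (plain Python; ThreefryKeyArray and the raw pair are invented stand-ins, not JAX APIs):

```python
# Hypothetical sketch: dump a key to something reproducible (e.g. for
# an MWE) and restore it later. All names here are invented stand-ins.

import json

class ThreefryKeyArray:
    impl_name = "threefry2x32"
    def __init__(self, raw):
        self.raw = tuple(raw)

def dump_key(key):
    # Record the implementation name and the raw bits.
    return json.dumps({"impl": key.impl_name, "raw": list(key.raw)})

def restore_key(blob):
    data = json.loads(blob)
    assert data["impl"] == ThreefryKeyArray.impl_name
    return ThreefryKeyArray(data["raw"])

key = ThreefryKeyArray((2634740717, 3214329440))
restored = restore_key(dump_key(key))
```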
I'm happy to code it up, and there would be no user-facing interface changes (except for exposing the key classes), so nothing would break.