jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

expose `random_wrap` and `random_unwrap` (plus discussion of early typed key surprises) #14046

Closed NeilGirdhar closed 1 year ago

NeilGirdhar commented 1 year ago

I've been working on producing a MWE for this issue for the last month or so. Part of this is setting everything up identically right before the freeze.

What I want to do is run the program that fails and print out the key array, and then produce a MWE where I reconstruct that key array. The problem is that I'm using the new custom PRNG and these don't seem to have any interface for construction from an array.

I don't want to step on any toes, but it seems to me that the design complicates this kind of use pattern. There are some excellent aspects to the current design. Jax does an excellent job of having small interfaces, and the new PRNG is no exception. However, it's just one function too small for my needs.

Trying to shoehorn this feature request to fit the pimpl idiom looks like it's going to be awkward. It would end up being something like PRNGKeyArray(impl, key_array), which means exposing the pimpl objects (threefry_prng_impl, etc.), and PRNGKeyArray, and providing access to the pimpl objects—all things that go against the whole pimpl idea in the first place. Or perhaps adding methods to construct from a base array for each implementation?

I want to humbly suggest removing the pimpl pattern, and switching to using ordinary classes and inheritance. Pimpl was popularized in C++ as a compilation firewall (for example, this excellent description by C++ standards chair Herb Sutter). The idea is that it allows you to completely hide private members (methods and data) in compiled source, and expose a minimal header with only public methods. That's why it rarely turns up in Python since there is no distinction between source and header.

The ordinary Python way to do this is:

class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
    def __init__(self, key_data: Array):
        assert not isinstance(key_data, core.Tracer)
        self._base_array = key_data

    @classmethod
    def _key_shape(cls) -> core.Shape:
        raise NotImplementedError

    def _split(self, num: int) -> Self:
        raise NotImplementedError

    def _random_bits(self, bit_width: int, shape: core.Shape) -> Array:
        raise NotImplementedError

    def _fold_in(self, data: Array) -> Self:
        raise NotImplementedError

class Threefry2x32Key(PRNGKeyArray):
    ...

class RBGKey(PRNGKeyArray):
    ...

class UnsafeRBGKey(PRNGKeyArray):
    ...

Then, you can reproduce training examples by dumping:

print(type(key_array))  # This is a function I want.
print(key_array.unsafe_raw_array())

and then restoring:

SomeKeyType(raw_array)  # This is the other function I want.

I'm happy to code it up, and there would be no user-facing interface changes (except for exposing the key classes), so nothing would break.

NeilGirdhar commented 1 year ago

Current workaround:

    print(key_array.impl)  # Get type.
    print(key_array.unsafe_raw_array())
    # and then...
    raw_array = jnp.array((2634740717, 3214329440), dtype=jnp.uint32)
    from jax._src.prng import PRNGKeyArray, threefry_prng_impl
    key_array = PRNGKeyArray(threefry_prng_impl, raw_array)  # Create using a type and value.

froystig commented 1 year ago

If you're importing _src.prng for workarounds, a better workaround may be:

jax._src.prng.random_wrap(raw_array, impl=impl)

In particular, this works under jit.

I'll try to write back soon about the broader questions up top. Just wanted to get this to you in the meantime. As always, thanks for putting thought into these things and for the discussion/feedback, Neil!

froystig commented 1 year ago

Coming back to this, the "workaround" might actually be an answer to part of the question. @NeilGirdhar – regardless of the idiom by which we set up PRNG implementations, my original thought was to expose random_wrap and random_unwrap (perhaps by a different name) as public API functions. Would that suffice for your needs as more than a workaround?

I'm also happy to separately review improvements to the internal idiom we use for PRNG implementations (e.g. a switch to inheritance), assuming it's a separate concern, and especially if it cleans things up or makes them easier!

The reason I propose that random_{,un}wrap be the only way in and out of key arrays follows from my previous comment: these functions bind corresponding array-casting primitives, and in particular calling them is invariant under jit (and staging more generally). Constructing or unwrapping a Python PRNGKeyArray (or derivative thereof) is merely the corresponding impl rule for these primitives.

Does this seem right?

NeilGirdhar commented 1 year ago

Sorry for the delay. I wanted to invest some time in understanding random_wrap. I'm afraid I don't understand what JIT invariance means. My understanding of Jax internals still has a long way to go. I don't want to use up your valuable time, but why is it that simple construction doesn't work with the JIT? In other words, why do you need to do random_wrap(array, impl) instead of PRNGKeyArray(impl, array)?

Would that suffice for your needs as more than a workaround?

That means you'll be exposing the PRNGImpl classes like threefry_prng_impl? Yes, that definitely works.

I'm also happy to separately review improvements to the internal idiom we use for PRNG implementations (e.g. a switch to inheritance), assuming it's a separate concern, and especially if it cleans things up or makes them easier!

Yes, it's a separate concern! And I love your commitment to finding the best design.

NeilGirdhar commented 1 year ago

This seems to be related:

from jax import enable_custom_prng, vjp
from jax.random import PRNGKey

with enable_custom_prng():
    def f(i):
        return PRNGKey(i)

    out, f_vjp = vjp(f, 1)  # Fails!

Would it be possible to make new-style KeyArrays py-trees? And then arm PRNGKey with a custom derivative that sends a zero cotangent back?

Also, how do I produce a zero cotangent for a new style key array?

froystig commented 1 year ago

In other words, why do you need to do random_wrap(array, impl) instead of PRNGKeyArray(impl, array)?

Calling random_wrap(array, impl) binds the random_wrap_p primitive. When we're staging, we can capture it as such, and so we can later lower it, as well as lower operations on its output (such as slicing; see the gather below):

>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...   k = jax._src.prng.random_wrap(x, impl=jax._src.prng.threefry_prng_impl)
...   k = k[:2, :3]
...   return jax.vmap(jax.vmap(jax.random.bernoulli))(k)
... 
>>> with jax.enable_custom_prng():
...     print(jax.make_jaxpr(f)(jnp.ones((5, 6, 2), jnp.uint32)))
... 
{ lambda ; a:u32[5,6,2]. let
    b:key<fry>[5,6] = random_wrap[impl=fry] a
    c:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    e:i32[2] = concatenate[dimension=0] c d
    f:key<fry>[2,3] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(2, 3)
      unique_indices=True
    ] b e
    g:u32[2,3] = random_bits[bit_width=32 shape=()] f
    h:u32[2,3] = shift_right_logical g 9
    i:u32[2,3] = or h 1065353216
    j:f32[2,3] = bitcast_convert_type[new_dtype=float32] i
    k:f32[2,3] = sub j 1.0
    l:f32[] = sub 1.0 0.0
    m:f32[2,3] = mul k l
    n:f32[2,3] = add m 0.0
    o:f32[2,3] = max 0.0 n
    p:bool[2,3] = lt o 0.5
  in (p,) }
>>> 

When evaluating eagerly, the primitive is indeed implemented as the expression PRNGKeyArray(impl, array).
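
To make the "invariant under jit" point concrete, here is a minimal sketch. It assumes a JAX release that exposes the public jax.random.wrap_key_data (the name given when this issue was closed) rather than the internal jax._src.prng.random_wrap used above, and it assumes the "threefry2x32" implementation name. The helper name sample is made up for illustration.

```python
import jax
import jax.numpy as jnp

IMPL = "threefry2x32"  # assumption: the threefry implementation name

def sample(raw_bits):
    # Wrapping binds a primitive, so this line behaves the same whether
    # raw_bits is a concrete array (eager) or a tracer (under jit).
    key = jax.random.wrap_key_data(raw_bits, impl=IMPL)
    return jax.random.uniform(key, (3,))

raw = jnp.array([2634740717, 3214329440], dtype=jnp.uint32)

eager = sample(raw)            # evaluated op by op
jitted = jax.jit(sample)(raw)  # staged out and compiled

print(eager, jitted)
```

Both calls should produce the same samples, since the same function body is either run eagerly or traced and lowered.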

Would it be possible to make new-style KeyArrays py-trees?

Not quite, because they aren't equivalent to (some tuple of) basic arrays. It's maybe easiest to consider a few examples of what would go wrong if you made them pytrees (naturally, ones that flatten to the underlying uint32 array). One example is that we can vmap a function over them down to the last dimension, then operate on the key bits—something we want to disallow. Another example is that if we tree_map over them, we'd accidentally "unwrap" them back down to uint32 arrays.
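
A small sketch of the shape distinction described above, assuming a JAX release with the public jax.random.key and jax.random.key_data functions and the "threefry2x32" implementation (whose raw key data has a trailing dimension of size 2):

```python
import jax
import jax.numpy as jnp

# A typed key array: four keys, shape (4,).
keys = jax.random.split(jax.random.key(0, impl="threefry2x32"), 4)

# The underlying bits carry an extra trailing dimension that the typed
# key array hides from shape-aware transformations.
bits = jax.random.key_data(keys)

# vmap over the typed array maps over keys, never over key bits:
# each call to the mapped function receives one scalar-shaped key.
samples = jax.vmap(lambda k: jax.random.uniform(k))(keys)

print(keys.shape, bits.shape, bits.dtype, samples.shape)
```

If key arrays flattened to their uint32 data, a second vmap could reach that trailing dimension and operate on individual key words, which is the behavior the typed representation rules out.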

Also, how do I produce a zero cotangent for a new style key array?

What do you do for current uint32-style raw key arrays?

froystig commented 1 year ago

This seems to be related: [...]

That's a bug! Thanks for catching it. Filed #14856

froystig commented 1 year ago

We should tag the tracking issue too: #9263

This is a great thread Neil! Thanks for kicking the tires and asking useful questions.

NeilGirdhar commented 1 year ago

It's maybe easiest to consider a few examples of what would go wrong if you made them pytrees (naturally, ones that flatten to the underlying uint32 array). One example is that we can vmap a function over them down to the last dimension, then operate on the key bits—something we want to disallow. Another example is that if we tree_map over them, we'd accidentally "unwrap" them back down to uint32 arrays.

Just a thought, but does Jax support structured arrays? If so, you could make all of the key information a single record. Then, it would be essentially atomic in the eyes of vmap and tree-map.

A related, but less elegant solution would be to register new dtypes corresponding to the key types of your PRNGs.

What do you do for current uint32-style raw key arrays?

I don't do anything yet. I'm investigating storing the key array that I used on various components of my model in the inference output. In order for that to work, key arrays need to be pytrees, and I need to be able to produce zero cotangents for them.

It does seem a bit odd that I can JIT a function that accepts a seed dynamically, but I can't jit a function that accepts a key array dynamically.

froystig commented 1 year ago

Just a thought, but does Jax support structured arrays? If so, you could make all of the key information a single record.

It doesn't, but we have a semi-internal notion of "opaque dtypes" on which the key arrays (and one or two other such things) are built. See #11768, renamed in #12170.

The two are slightly different. Structured arrays allow for a standard numpy array of structs, whereas opaque dtypes allow the data of an array to be more arbitrary in structure, and allow the array to customize which operations it supports. With opaque dtypes, our underlying data doesn't have to be an array-of-records, slicing doesn't have to be slicing into such an array-of-records, various array operations might not be allowed, etc.

Then, it would be essentially atomic in the eyes of vmap and tree-map.

Yeah, there might be several ways to achieve this type of "atomicity in the eyes of vmap." The opaque dtypes mechanism is sufficient but probably not necessary. Another approach we've kicked around is a generalization of pytrees that involves instantiating typeclasses for different transformations (e.g. vmappable means roughly that you have a pair of rules from_elt and to_elt). This isn't a complete exploration though, especially not for all transformations.

It does seem a bit odd that I can JIT a function that accepts a seed dynamically, but I can't jit a function that accepts a key array dynamically.

How do you mean? This sounds like another potential bug. I can do this:

>>> with jax.enable_custom_prng():
...     print(jax.jit(jax.random.bernoulli)(jax.random.PRNGKey(3)))
... 
False

NeilGirdhar commented 1 year ago

Thanks for the very informative explanation!

How do you mean? This sounds like another potential bug. I can do this:

I guess I misunderstood the consequences of key arrays not being pytrees. If they're not a pytree, how are they dynamically passed to a jitted function?

And an unrelated question, after #14856 is fixed, how should I produce a zero cotangent for a key array? For a normal pytree, you would do tree_map(lambda x: jnp.zeros_like(x), some_tree).

froystig commented 1 year ago

If they're not a pytree, how are they dynamically passed to a jitted function?

They're indeed not a pytree, but they do correspond to a "JAX type" (that error you were getting with the VJP bug is spurious). This is part of what the internal "opaque dtype" mechanism sets up. In other words: we made it work!

And an unrelated question, after #14856 is fixed, how should I produce a zero cotangent for a key array?

I'm not sure that it makes sense to make one? Key arrays correspond to uint-dtyped arrays under the hood, but either way neither of those forms a vector space. If a function returns a key array, what does it mean to differentiate it? Int dtypes present a special case that we (perhaps questionably) decided to handle, and so zeros_like works there.

Pragmatically: if you're taking VJP of a function whose output has some key arrays and some numeric ones, you could wrap it in one that only outputs numeric ones (and drops the keys) and take VJP of that?

There's also that has_aux option to jax.vjp. Does that get you what you need?
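
As a rough illustration of the has_aux suggestion, the sketch below returns a key as an auxiliary output so that only the numeric loss takes part in the VJP. The function step and its arithmetic are made up for illustration, and the example assumes a JAX release with typed keys via jax.random.key.

```python
import jax
import jax.numpy as jnp

def step(x, key):
    # Hypothetical model step: consumes one subkey, returns the next key.
    new_key, sub = jax.random.split(key)
    noise = jax.random.normal(sub, x.shape)
    loss = jnp.sum((x + 0.1 * noise) ** 2)
    return loss, new_key  # (differentiated output, auxiliary output)

x = jnp.arange(3.0)
key = jax.random.key(0)

# The key is closed over, so x is the only primal input. With has_aux=True,
# the second output is passed through untouched and needs no cotangent.
loss, step_vjp, new_key = jax.vjp(lambda x_: step(x_, key), x, has_aux=True)
(x_bar,) = step_vjp(jnp.ones_like(loss))

print(loss, x_bar, jax.random.key_data(new_key))
```

The cost, as noted in the thread, is that the function has to be restructured so that everything non-differentiated travels through the auxiliary output.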

I do appreciate your point that the interaction between key arrays and our AD API changes as we upgrade to typed key arrays, and some user-side code might need rewriting. We'll have to think on whether we can make the upgrade easier...

NeilGirdhar commented 1 year ago

I do appreciate your point that the interaction between key arrays and our AD API changes as we upgrade to typed key arrays, and some user-side code might need rewriting. We'll have to think on whether we can make the upgrade easier...

Right. I may be able to work around this, but it's going to be complicated.

One pattern that's pretty common in my code is to have classes like this:

@dataclass
class GeometrySamplerLoss:
    attention_prediction_error: RealArray
    rivals_prediction_error: RealArray
    attention_curvature: RealArray
    rivals_curvature: RealArray

    @classmethod
    def zeros(cls) -> GeometrySamplerLoss:
        z = jnp.asarray(0.0)
        return GeometrySamplerLoss(z, z, z, z)

    @classmethod
    def cotangent(cls) -> GeometrySamplerLoss:
        o = jnp.asarray(1.0)
        z = jnp.asarray(0.0)
        return GeometrySamplerLoss(o, o, z, z)

When inferring a trajectory using scan, I use the zeros method to initialize the state, and the cotangent method to produce a custom cotangent to be pushed into a VJP to train these model components.

I understand your idea about not adding key arrays to model component outputs, but instead creating auxiliary outputs. This will complicate my design. I'll have to have a pair of outputs: one for regular things, and one for auxiliary things. (I guess while I'm at it, I can put all of the non-differentiated things in the auxiliary output, which is a benefit.)

Alternatively, we could consider something like jax.custom_derivatives.zero_cotangent:

def zero_cotangent(x: T) -> T:
  """Return a zero cotangent like x."""

Is that what interpreters.ad.instantiate_zeros does?

They're indeed not a pytree, but they do correspond to a "JAX type"

Maybe I don't understand the definition of "pytree". I thought anything that supports tree_util.tree_flatten is a pytree? Or does pytree imply an aggregate type? So Array is not a pytree for that reason?

froystig commented 1 year ago

Maybe I don't understand the definition of "pytree". [...]

I rather meant more narrowly that the type is not registered with tree_util.register_pytree_node[_class]. By your more basic definition, key arrays are indeed pytrees (specifically, leaves, just like other arrays are).
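
The "leaf" reading can be checked with a short sketch. It assumes a JAX release with jax.random.key and the jax.dtypes.prng_key scalar type; the dictionary state is just an illustrative container.

```python
import jax
import jax.numpy as jnp

state = {"key": jax.random.key(0), "weights": jnp.ones(3)}

# Dict entries are flattened in sorted key order: "key", then "weights".
leaves = jax.tree_util.tree_leaves(state)
is_key = [bool(jnp.issubdtype(leaf.dtype, jax.dtypes.prng_key)) for leaf in leaves]

# tree_map hands the key array to the function whole; it is not
# unwrapped into its uint32 data.
mapped = jax.tree_util.tree_map(lambda leaf: leaf, state)

print(len(leaves), is_key, mapped["key"].dtype)
```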

Right. I may be able to work around this, but it's going to be complicated.

Yeah, I can understand that this is more complex to express, at least relative to what you could write earlier. We'll have to think about this, but perhaps slightly orthogonally.

For example, maybe we should allow None or a SymbolicZero(shape, dtype) as a generalized convenient shorthand for (co-)tangents that indicate "this is not involved in differentiation." The shorthand here is really just to avoid the has_aux and closure rewriting, but it would have essentially equivalent meaning. And it could be nice for more than only key arrays, too.

Alternatively, we could consider something like jax.custom_derivatives.zero_cotangent

This sounds like it'd be along the lines of the above, in that it could be how you'd obtain such a hypothetical SymbolicZero(shape, dtype). We might also want it for tangents. Maybe we could call it symbolic_zero_perturbation_like or something that's less verbose but along the same lines.

I'm thinking out loud here.

Is that what interpreters.ad.instantiate_zeros does?

Kind of, but not exactly in that it doesn't know about primal-tangent correspondences. Anyway, that's an internal helper, which will be rendered only available via jax._src at the next opportunity. I wouldn't depend on it.

NeilGirdhar commented 1 year ago

This sounds like it'd be along the lines of the above, in that it could be how you'd obtain such a hypothetical SymbolicZero(shape, dtype). We might also want it for tangents. Maybe we could call it symbolic_zero_perturbation_like or something that's less verbose but along the same lines.

I'm thinking out loud here.

Awesome.

I think there's a lot of beauty to the way that you can return None out of a custom VJP to represent a zero cotangent for a non-differentiated primal input. What I'm looking for is a way to pass something into a custom VJP to represent a zero cotangent of a non-differentiated primal output.

So, I really like your idea of SymbolicZero. Like you say, it would be nice for it to work for more than only key arrays—ideally, all "Jax types"?

Anyway, that's an internal helper, which will be rendered only available via jax._src at the next opportunity. I wouldn't depend on it.

I'm definitely not using that! :smile: I only found it because I guessed that you might be producing zero cotangents somewhere in Jax code. If you already have a zero cotangent function inside Jax, I figured it would strengthen the argument for exposing such a function.

froystig commented 1 year ago

I brought up both None and SymbolicZero because both come up in custom JVPs as of #14570. Symbolic zeros enter the rule, but the rule can return Nones if preferred. Plain Nones are fine if information like the shape and dtype aren't needed or can be inferred. I haven't entirely thought through whether a None is enough when invoking a VJP.

If you already have a zero cotangent function inside Jax, I figured it would strengthen the argument for exposing such a function.

To be sure, I think we've determined that this isn't quite the function you're looking for anyway – is that right?

That is, your request is not "let me obtain a proper zero cotangent for this arbitrary primal array." Instead it is "let me indicate, by passing a sentinel value in place of a cotangent, an intent not to perturb at this output position."

(In fact, there isn't a cotangent space for your primal array, and in particular no zero cotangent.)

patrick-kidger commented 1 year ago

Is that what interpreters.ad.instantiate_zeros does?

Kind of, but not exactly in that it doesn't know about primal-tangent correspondences. Anyway, that's an internal helper, which will be rendered only available via jax._src at the next opportunity. I wouldn't depend on it.

I do depend on this... :D Please keep it available somewhere!

Plain Nones are fine if information like the shape and dtype aren't needed or can be inferred. I haven't entirely thought through whether a None is enough when invoking a VJP.

I'm pretty sure plain Nones are totally fine. That's the API offered by equinox.filter_vjp and it works smoothly. (After all you can always just reconstruct the symbolic zero with a Zero(core.get_aval(primal).at_least_vspace()).)

NeilGirdhar commented 1 year ago

That is, your request is not "let me obtain a proper zero cotangent for this arbitrary primal array." Instead it is "let me indicate, by passing a sentinel value in place of a cotangent, an intent not to perturb at this output position."

Yes, exactly!

Ideally, the system should work on components:

primals_out, f_vjp = vjp(f, *primals_in)
cotangents_out = f_vjp(*cotangents_in)

Where cotangents_in is a tuple of pytrees of cotangent inputs, some of which may contain symbolic zero components. For example, cotangents_in[3] may be equal to (jnp.ones(3), {'foo': SymbolicZero(PRNGKey(2))}).

This goes beyond what's supported by None in a custom VJP backwards pass (which can only send None for an entire cotangent output, not for a component). My ideal interface is therefore: SymbolicZero(any_pytree) where any_pytree could be an array, a key array, or any aggregate pytree type.
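
For numeric outputs, something close to this per-component sentinel can already be emulated on the user side. The sketch below is not a JAX API: fill_none_cotangents is a hypothetical helper, and it assumes each None stands in for exactly one array leaf of the primal output. It does not address key arrays, for which no zero cotangent exists.

```python
import jax
import jax.numpy as jnp

def fill_none_cotangents(cotangents, primals_out):
    """Hypothetical helper: replace each None leaf with zeros like the matching primal output."""
    return jax.tree_util.tree_map(
        lambda ct, p: jnp.zeros_like(p) if ct is None else ct,
        cotangents,
        primals_out,
        is_leaf=lambda node: node is None,  # treat None as a leaf, not an empty subtree
    )

def f(x):
    return jnp.sin(x), {"foo": x ** 2}

x = jnp.arange(3.0)
primals_out, f_vjp = jax.vjp(f, x)

# None marks "do not perturb this output component".
cotangents = (jnp.ones(3), {"foo": None})
(x_bar,) = f_vjp(fill_none_cotangents(cotangents, primals_out))

print(x_bar)
```

Only the sin output contributes here, so x_bar equals cos(x); a true symbolic zero would avoid materializing the zeros array at all.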

I'm pretty sure plain Nones are totally fine. That's the API offered by equinox.filter_vjp and it works smoothly

So you support None even in place of sub-elements of cotangents (as in my example above)?

froystig commented 1 year ago

Closing because at this point we have jax.random.{key_data,key_impl} for unwrapping and jax.random.wrap_key_data for wrapping.
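
A minimal sketch of the dump-and-restore workflow from the top of this thread, using the three public functions named above. It assumes a JAX release that also provides jax.random.key for creating typed keys; the seed value is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)

# Dump: the implementation spec and the raw uint32 key data.
impl = jax.random.key_impl(key)
raw = jax.random.key_data(key)
print(impl, raw)

# Restore: rebuild a typed key from the raw data and the implementation.
restored = jax.random.wrap_key_data(jnp.asarray(raw), impl=impl)

# The restored key should drive the same random stream as the original.
print(jax.random.uniform(key), jax.random.uniform(restored))
```

In a real reproduction script the printed values would be pasted in as literals, much like the raw_array in the earlier workaround.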