jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Provide utilities for creating pytrees filled with random samples #20594

Open SimonKoop opened 7 months ago

SimonKoop commented 7 months ago

The utilities provided by jax.tree_util, together with the flexibility of JAX transformations, make it easy to write code that works not just for arrays but for general pytrees. For example, JAX makes it easy to add (the arrays in) two pytrees together, scale (the arrays of) a pytree by a scalar, or compute the gradient of some function with respect to (the arrays in) a pytree. One operation that is currently more involved is filling a pytree with random samples from some distribution. For instance, generalizing the following code to arbitrary pytrees is relatively involved:

new_array = old_array + .2 * jax.random.normal(key=my_key, shape=old_array.shape)

This could be made easy by introducing utilities like the following:

from typing import Callable, Any
from functools import partial

import jax
from jax import tree_util
from jax import numpy as jnp
import numpy as np

Pytree = Any

def is_array_like(leaf):
    return isinstance(leaf, (jax.Array, np.ndarray, np.generic, float, complex, bool, int))

def _process_leaf(leaf, key, sampler, sampling_criterion):
    if sampling_criterion(leaf):
        shape = jnp.shape(leaf)
        return sampler(key=key, shape=shape)
    return leaf

def sample_like(key:jax.Array, like:Pytree, sampler:Callable, sampling_criterion:Callable=is_array_like, **sampler_kwargs):
    """ 
    Sample random numbers using sampler to fill a pytree with the same structure as like
    :parameter key: prng key
    :parameter like: pytree determining the structure of the output
    :parameter sampler: callable taking two arguments
        - key: a prng key
        - shape: shape of the output array
        used for creating the random samples
        e.g. jax.random.normal
    :parameter sampling_criterion: Callable that takes leaves and returns a boolean indicating whether 
        a random sample should be drawn for this leaf
        the default criterion returns whether the leaf is array-like
    :parameter sampler_kwargs: any keyword arguments to be passed to the sampler
    :returns: pytree with the same structure as like
        any leaf for which sampling_criterion returns True will be sampled from sampler
        any other leaf will be kept
    """
    num_leaves = len(jax.tree_util.tree_leaves(like))
    keys = jax.random.split(key, num_leaves)
    key_tree = tree_util.tree_unflatten(
        treedef=tree_util.tree_structure(like),
        leaves=keys
    )
    sample = tree_util.tree_map(
        partial(
            _process_leaf, 
            sampler=partial(sampler, **sampler_kwargs), 
            sampling_criterion=sampling_criterion
            ),
        like,
        key_tree
    )
    return sample

normal_like = partial(sample_like, sampler=jax.random.normal)

One motivating example is adding noise to the gradients of some neural network.
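A minimal sketch of that use case, assuming a toy linear model and a fixed noise scale (`loss_fn`, `add_gradient_noise` and `stddev=0.1` are illustrative names and values, not from the thread). It derives one subkey per gradient leaf and adds zero-mean Gaussian noise with the leaf's shape and dtype:

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy linear model with a squared-error loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def add_gradient_noise(key, grads, stddev=0.1):
    # One independent subkey per leaf, arranged in the same tree structure.
    treedef = jax.tree_util.tree_structure(grads)
    keys = jax.random.split(key, treedef.num_leaves)
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # Add zero-mean Gaussian noise matching each leaf's shape and dtype.
    return jax.tree_util.tree_map(
        lambda g, k: g + stddev * jax.random.normal(k, g.shape, g.dtype),
        grads,
        key_tree,
    )


params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

grads = jax.grad(loss_fn)(params, x, y)
noisy_grads = add_gradient_noise(jax.random.PRNGKey(0), grads)
```

Because the subkeys are derived deterministically from the input key, calling `add_gradient_noise` twice with the same key yields identical noise.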

jakevdp commented 7 months ago

Thanks for the request! My first inclination here is that it's best for us to leave this up to the user. Your solution based on jax.tree.map and jax.tree.unflatten seems pretty reasonable for your particular use-case, and this is a circumstance where every user will probably want something slightly different (e.g. a different distribution or scale per leaf, etc.) and it might quickly become a very large API.

The fact that you were able to do what you need in a dozen or so lines of code suggests others might be able to as well – what do you think?
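As an illustration of how one such per-leaf variation (a different scale per leaf) can already be expressed with existing utilities, here is a minimal sketch. The helper `scaled_normal_like` is hypothetical; it assumes `scales` is a pytree with the same structure as `tree`, holding one standard deviation per leaf:

```python
import jax
import jax.numpy as jnp


def scaled_normal_like(key, tree, scales):
    # Hypothetical helper: `scales` must have the same structure as `tree`,
    # with one scalar standard deviation per leaf.
    treedef = jax.tree_util.tree_structure(tree)
    keys = treedef.unflatten(list(jax.random.split(key, treedef.num_leaves)))
    return jax.tree_util.tree_map(
        lambda leaf, k, s: s * jax.random.normal(k, jnp.shape(leaf)),
        tree,
        keys,
        scales,
    )


params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
# Noise on the weights only: a zero scale switches a leaf off.
noise = scaled_normal_like(jax.random.PRNGKey(1), params, {"w": 0.1, "b": 0.0})
```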

SimonKoop commented 7 months ago

Thank you for the quick response! I can see how this could easily become too large an API. I do think adding Gaussian noise specifically has many applications, so perhaps a slightly generalized JAX version of torch.randn_like that works on pytrees could already cover a large share of use cases without requiring a large API?
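A pytree analogue of torch.randn_like could look roughly like the sketch below. The name `tree_randn_like` is hypothetical, and it assumes every leaf is a floating-point array or scalar, since jax.random.normal rejects integer dtypes. Like torch.randn_like, it keeps each leaf's shape and dtype:

```python
import jax
import jax.numpy as jnp


def tree_randn_like(key, tree):
    # Hypothetical pytree analogue of torch.randn_like: every leaf is replaced
    # by standard-normal samples with the same shape and dtype.
    # Assumes all leaves are floating-point (jax.random.normal rejects ints).
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    keys = jax.random.split(key, len(leaves))
    samples = [
        jax.random.normal(k, jnp.shape(x), jnp.result_type(x))
        for k, x in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)


params = {"w": jnp.ones((3, 2), dtype=jnp.bfloat16), "b": jnp.zeros(2)}
noise = tree_randn_like(jax.random.PRNGKey(0), params)
```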

SimonKoop commented 7 months ago

Alternatively, something a bit more high-level could work as well: something like prng_tree_map(f, tree, *rest, key, is_leaf=None), where the value at every leaf is given by f(x, *xs, key=leaf_key), with leaf_key deterministically derived from the key passed to prng_tree_map, based on the structure of tree. This would let users map all sorts of "random functions" over trees without requiring a dedicated API for each distribution.

Edit: something like

def prng_tree_map(f, tree, *rest, key, is_leaf=None):
  leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf)
  all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  keys = jax.random.split(key, len(leaves))
  return treedef.unflatten(f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_leaves), keys))
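A usage sketch for prng_tree_map as written above (copied here so the block is self-contained). `add_noise` reproduces the single-array expression from the top of the issue as a per-leaf function; the second call is an illustrative use of `*rest`, randomly choosing each element from one of two trees with the same structure:

```python
import jax
import jax.numpy as jnp


def prng_tree_map(f, tree, *rest, key, is_leaf=None):
    # Same recipe as above: one subkey per leaf of `tree`.
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf)
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    keys = jax.random.split(key, len(leaves))
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_leaves), keys)
    )


def add_noise(x, *, key, stddev=0.2):
    # The one-array expression from the issue, applied per leaf.
    return x + stddev * jax.random.normal(key, jnp.shape(x))


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
noisy = prng_tree_map(add_noise, params, key=jax.random.PRNGKey(0))

# Using *rest: pick each element from `params` or from a zero tree at random.
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
mixed = prng_tree_map(
    lambda x, y, *, key: jnp.where(jax.random.bernoulli(key, 0.5, jnp.shape(x)), x, y),
    params,
    zeros,
    key=jax.random.PRNGKey(1),
)
```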

Edit 2: if you provide an analogue of tree_map_with_path too, your example of using a different distribution per leaf can also be implemented easily. Code for prng_tree_map_with_path could be:

def prng_tree_map_with_path(f, tree, *rest, key, is_leaf=None):
  keypath_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf)
  # Transpose [(path, leaf), ...] into [paths, leaves] so f receives each leaf's path first.
  keypath_leaves = list(zip(*keypath_leaves))
  all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
  # treedef.num_leaves also works for an empty tree, where keypath_leaves[0] would raise IndexError.
  keys = jax.random.split(key, treedef.num_leaves)
  return treedef.unflatten(f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_keypath_leaves), keys))

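A sketch of the "different distribution per leaf" case using prng_tree_map_with_path (repeated here so the block runs on its own). The per-leaf rule in `init_leaf` is an illustrative assumption, not something proposed in the thread: dict entries named "b" are drawn uniformly from [-0.1, 0.1), all other leaves from a normal distribution scaled by 0.02. It relies on dict paths ending in a jax.tree_util.DictKey, whose `.key` attribute holds the dict key:

```python
import jax
import jax.numpy as jnp


def prng_tree_map_with_path(f, tree, *rest, key, is_leaf=None):
    keypath_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf)
    all_keypath_leaves = list(zip(*keypath_leaves)) + [treedef.flatten_up_to(r) for r in rest]
    keys = jax.random.split(key, treedef.num_leaves)
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_keypath_leaves), keys)
    )


def init_leaf(path, x, *, key):
    # Illustrative rule: leaves stored under the dict key "b" get uniform
    # samples; everything else gets small Gaussian samples.
    if getattr(path[-1], "key", None) == "b":
        return jax.random.uniform(key, jnp.shape(x), minval=-0.1, maxval=0.1)
    return 0.02 * jax.random.normal(key, jnp.shape(x))


params = {
    "layer1": {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)},
    "layer2": {"w": jnp.zeros((3, 1)), "b": jnp.zeros(1)},
}
new_params = prng_tree_map_with_path(init_leaf, params, key=jax.random.PRNGKey(42))
```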
jakevdp commented 7 months ago

I'm not convinced of the need to provide a set API for something that can be expressed in a few lines using existing utilities. Maybe as a compromise, we could provide these examples as recipes in the tree_util documentation – what do you think?

SimonKoop commented 7 months ago

Providing them as recipes could be nice. Functions like prng_tree_map and prng_tree_map_with_path would seem to fit in well with tree_util, but offering them as examples instead could help people familiarize themselves with pytrees. The number of lines argument goes for most functions in tree_util though.

jakevdp commented 7 months ago

"The number of lines argument goes for most functions in tree_util though."

Lines of code is not the only consideration: most functions in jax.tree_util with succinct implementations are built on non-public APIs (such as references to default_registry, which is private by design). Without these utilities, it would be impossible to, e.g., flatten a tree using only public APIs.

I think it's a good rule-of-thumb to avoid expanding a package's API surface with utilities that can be expressed succinctly in terms of already available public APIs.

carlosgmartin commented 7 months ago

A possible recipe could be the following:

import jax

def random_split_tree(key, tree):
    treedef = jax.tree.structure(tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree.unflatten(treedef, keys)

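def random_like(key, tree, fun=jax.random.normal):
    keys = random_split_tree(key, tree)
    # fun is called as fun(key=..., shape=...), so every leaf of tree must have a .shape (plain Python scalars do not).
    return jax.tree.map(lambda key, leaf: fun(key=key, shape=leaf.shape), keys, tree)
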
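A usage sketch for this recipe (the two functions are repeated so the block is self-contained; it assumes a recent JAX release that includes the jax.tree module). Other samplers can be passed through `fun`, with extra arguments bound via functools.partial:

```python
import functools

import jax
import jax.numpy as jnp


def random_split_tree(key, tree):
    treedef = jax.tree.structure(tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree.unflatten(treedef, keys)


def random_like(key, tree, fun=jax.random.normal):
    keys = random_split_tree(key, tree)
    return jax.tree.map(lambda key, leaf: fun(key=key, shape=leaf.shape), keys, tree)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
gaussian = random_like(jax.random.PRNGKey(0), params)
# Bind distribution parameters ahead of time; random_like only supplies key and shape.
uniform = random_like(
    jax.random.PRNGKey(0),
    params,
    fun=functools.partial(jax.random.uniform, minval=-1.0, maxval=1.0),
)
```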
carlosgmartin commented 4 months ago

By the way, optax has a function, optax.tree_utils.tree_random_like, that does just this.