Open SimonKoop opened 7 months ago
Thanks for the request! My first inclination here is that it's best for us to leave this up to the user. Your solution based on jax.tree.map and jax.tree.unflatten seems pretty reasonable for your particular use case, and this is a circumstance where every user will probably want something slightly different (e.g. a different distribution or scale per leaf, etc.), so it might quickly become a very large API.
The fact that you were able to do what you need in a dozen or so lines of code suggests others might be able to as well – what do you think?
Thank you for the quick response! I can see how this could easily become too large of an API. I do think adding specifically Gaussian noise has many applications, so perhaps a slightly generalized JAX version of torch.randn_like that works for pytrees, could already satisfy a large portion of use cases without requiring a large API?
Alternatively, something a bit more high-level could work as well: something like prng_tree_map(f, tree, *rest, key, is_leaf=None), where the value at every leaf is given by f(x, *xs, key=leaf_key), with leaf_key deterministically derived from the key passed to prng_tree_map, based on the structure of tree. This would allow users to map all sorts of "random functions" over trees without requiring a large API surface.
Edit: something like

def prng_tree_map(f, tree, *rest, key, is_leaf=None):
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf)
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    keys = jax.random.split(key, len(leaves))
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_leaves), keys)
    )
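This gives a pytree-aware analogue of torch.randn_like in one call. A self-contained sketch (the parameter tree and the seed are illustrative, not from the thread):

```python
import jax
import jax.numpy as jnp

def prng_tree_map(f, tree, *rest, key, is_leaf=None):
    # Flatten the primary tree and align any extra trees to its structure.
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf)
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    # One independent PRNG key per leaf, derived deterministically from `key`.
    keys = jax.random.split(key, len(leaves))
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_leaves), keys)
    )

# Sample standard-normal noise with the same shapes as an (illustrative)
# parameter tree -- the randn_like use case from the thread.
params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
noise = prng_tree_map(
    lambda x, key: jax.random.normal(key, x.shape), params, key=jax.random.key(0)
)
```

Each leaf gets its own key, so the samples in different leaves are independent, and the whole result is reproducible from the single top-level key.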
Edit 2: if you provide an analogue for tree_map_with_path too, your example of using a different distribution per leaf can also easily be implemented. Code for prng_tree_map_with_path could be:

def prng_tree_map_with_path(f, tree, *rest, key, is_leaf=None):
    keypath_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf)
    keypath_leaves = list(zip(*keypath_leaves))
    all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
    keys = jax.random.split(key, len(keypath_leaves[0]))
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_keypath_leaves), keys)
    )
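To illustrate the per-leaf-distribution use case mentioned above, here is a self-contained sketch that dispatches on the key path (the tree, the path test, and the distributions are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def prng_tree_map_with_path(f, tree, *rest, key, is_leaf=None):
    # Flatten to (key_path, leaf) pairs, then separate paths from leaves.
    keypath_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf)
    keypath_leaves = list(zip(*keypath_leaves))  # -> [paths, leaves]
    all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
    keys = jax.random.split(key, len(keypath_leaves[0]))
    return treedef.unflatten(
        f(*xs, key=leaf_key) for xs, leaf_key in zip(zip(*all_keypath_leaves), keys)
    )

def sample(path, leaf, key):
    # Pick a distribution per leaf based on its key path: uniform noise
    # for leaves whose path mentions "b", normal noise otherwise.
    if "b" in jax.tree_util.keystr(path):
        return jax.random.uniform(key, leaf.shape)
    return jax.random.normal(key, leaf.shape)

params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
noise = prng_tree_map_with_path(sample, params, key=jax.random.key(1))
```

Because f receives the key path as its first argument, any path-dependent sampling rule can be expressed this way without changing the utility itself.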
I'm not convinced of the need to provide a dedicated API for something that can be expressed in a few lines using existing utilities. Maybe as a compromise, we could provide these examples as recipes in the tree_util documentation, what do you think?
As a recipe this could be nice. Something like prng_tree_map and prng_tree_map_with_path would seem to fit fine in tree_util, but providing them as examples instead could help people familiarize themselves with pytrees. The number-of-lines argument applies to most functions in tree_util, though.
> The number of lines argument goes for most functions in tree_util though.
Lines of code is not the only consideration: most functions in jax.tree_util with succinct implementations are implemented in terms of non-public APIs (like references to default_registry, which is private by design). Without these utilities, it would be impossible to e.g. flatten a tree using public APIs.
I think it's a good rule-of-thumb to avoid expanding a package's API surface with utilities that can be expressed succinctly in terms of already available public APIs.
A possible recipe could be the following:
import jax

def random_split_tree(key, tree):
    treedef = jax.tree.structure(tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree.unflatten(treedef, keys)

def random_like(key, tree, fun=jax.random.normal):
    keys = random_split_tree(key, tree)
    return jax.tree.map(lambda key, leaf: fun(key=key, shape=leaf.shape), keys, tree)
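For context, the recipe can be exercised like this; the parameter tree, seeds, and the choice of jax.random.uniform as an alternative sampler are illustrative:

```python
import jax
import jax.numpy as jnp

def random_split_tree(key, tree):
    # One PRNG key per leaf, arranged in the same structure as `tree`.
    treedef = jax.tree.structure(tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree.unflatten(treedef, keys)

def random_like(key, tree, fun=jax.random.normal):
    # Sample a tree of the same structure/shapes, one fresh key per leaf.
    keys = random_split_tree(key, tree)
    return jax.tree.map(lambda key, leaf: fun(key=key, shape=leaf.shape), keys, tree)

params = {"w": jnp.ones((4, 2)), "b": jnp.ones((2,))}
noise = random_like(jax.random.key(0), params)
# Any sampler with a (key, shape) signature can be swapped in via `fun`:
uniform_noise = random_like(jax.random.key(1), params, fun=jax.random.uniform)
```

Splitting the key by treedef.num_leaves and unflattening keeps the keys aligned with the leaves, so jax.tree.map can pair them directly.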
By the way, optax has a function tree_random_like that does just this.
The utilities provided by jax.tree_util, together with the flexibility of JAX transformations, make it easy to write code that works not just for arrays but for general pytrees. For example, JAX makes it easy to add (the arrays in) two pytrees together, scale (the arrays of) a pytree by a scalar, or compute the gradient of some function with respect to (the arrays in) a pytree. One operation that is currently more involved is sampling a pytree full of random numbers from some distribution. That is, the following code is relatively involved to generalize to arbitrary pytrees:
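The snippet the request refers to is not reproduced here; presumably it was the plain single-array case, something along these lines (array and seed are illustrative):

```python
import jax
import jax.numpy as jnp

# For a single array, sampling same-shaped noise is a one-liner ...
x = jnp.ones((3, 5))
noise = jax.random.normal(jax.random.key(0), x.shape)

# ... but for a pytree of arrays, each leaf needs its own PRNG key,
# which requires flattening, splitting the key, and unflattening by hand.
```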
This could be made easy by introducing utilities like the following:
One motivating example is adding noise to the gradients of some neural network.
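That motivating example can be sketched end to end with existing public APIs; the toy model, loss, and noise scale below are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Toy quadratic loss over a tiny linear model.
    return jnp.sum((x @ params["w"] + params["b"]) ** 2)

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
grads = jax.grad(loss)(params, jnp.ones((4, 3)))

# Add Gaussian noise to every gradient leaf, one PRNG key per leaf.
treedef = jax.tree.structure(grads)
keys = jax.tree.unflatten(
    treedef, jax.random.split(jax.random.key(0), treedef.num_leaves)
)
noisy_grads = jax.tree.map(
    lambda g, k: g + 0.01 * jax.random.normal(k, g.shape), grads, keys
)
```

Because jax.grad returns a pytree matching params, the same split-and-unflatten pattern from the recipe above applies directly to the gradients.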