jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JIT explicit pytree arguments #8173

Open AdrienCorenflos opened 3 years ago

AdrienCorenflos commented 3 years ago

Hi,

I would like to propose an improvement to the JIT API that would bring it closer to the vmap and pmap syntax.

At the moment, when a function takes a pytree as an argument, the whole pytree must be marked either static or non-static; there is no way to make that distinction leaf by leaf. This is problematic for creating user-friendly interfaces (see, e.g., https://github.com/blackjax-devs/blackjax/wiki/Meeting-minutes), and it is a problem visibly encountered by many users, who have found different ways of circumventing it (e.g., https://github.com/cgarciae/treeo by @cgarciae).

To make the API flexible yet friendly, I believe users should be able to pass in an argument similar in style to the in_axes argument of vmap. For example:

def jit(fun, ..., static_flags=None):
  ...

This could be used as follows (already supported today via static_argnums):

@partial(jit, static_flags=[True, False])
def fun(x, y):
  return x + y

as well as the following (currently not supported):

@partial(jit, static_flags=[dict(x=True, y=False)])
def fun(dic):
  return dic["x"] + dic["y"]

hence providing an API similar to that of vmap, both in capabilities and in design. Ideally, this format should also be supported for gradient transformations, but that would probably require more thought, as the output would then be harder to define.
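For illustration, the proposed behavior can be roughly emulated on top of today's jax.jit: flatten the arguments, route the leaves flagged True through static_argnums, and trace the rest. This is a sketch under stated assumptions, not part of JAX: `jit_static_flags` is a hypothetical helper, its semantics (one bool or bool pytree per positional argument, True meaning static) are modelled on the examples above, and static leaves must be hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp


def jit_static_flags(fun, static_flags):
    """Hypothetical emulation of the proposed `static_flags` argument.

    `static_flags` has one entry per positional argument; each entry is a bool
    or a pytree of bools that is a prefix of that argument. Leaves flagged
    True are passed as static (so they must be hashable); the rest are traced.
    """

    @partial(jax.jit, static_argnums=(0, 1, 2))
    def _inner(treedef, is_static, static_leaves, dynamic_leaves):
        # Interleave static and traced leaves back into their original order.
        static_it, dynamic_it = iter(static_leaves), iter(dynamic_leaves)
        leaves = [next(static_it) if s else next(dynamic_it) for s in is_static]
        return fun(*jax.tree_util.tree_unflatten(treedef, leaves))

    def wrapper(*args):
        # Broadcast each flag over its argument to get one bool per leaf.
        full_flags = jax.tree_util.tree_map(
            lambda flag, arg: jax.tree_util.tree_map(lambda _: flag, arg),
            tuple(static_flags),
            args,
        )
        leaves, treedef = jax.tree_util.tree_flatten(args)
        is_static = tuple(bool(f) for f in jax.tree_util.tree_leaves(full_flags))
        static_leaves = tuple(x for x, s in zip(leaves, is_static) if s)
        dynamic_leaves = [x for x, s in zip(leaves, is_static) if not s]
        return _inner(treedef, is_static, static_leaves, dynamic_leaves)

    return wrapper


def scale(dic):
    # Python control flow on dic["x"] only works because x is static.
    if dic["x"] > 0:
        return dic["y"] * 2.0
    return dic["y"]


fast_scale = jit_static_flags(scale, [dict(x=True, y=False)])
result = fast_scale(dict(x=1, y=jnp.float32(3.0)))

add = jit_static_flags(lambda x, y: x + y, [True, False])
total = add(1, jnp.float32(2.0))
```

Passing the tree structure and the per-leaf flags as static arguments means a new static value (e.g. `x=-1`) triggers a recompilation, just as `static_argnums` does today.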

WDYT?

cc @rlouf @junpenglao

cgarciae commented 3 years ago

To add another motivation: Equinox (@patrick-kidger) currently relies on a partition function to split the tree into a purely static tree and a purely "dynamic" tree. If this feature becomes available, it would only need to return a boolean tree, which would simplify the API. It might also help simplify Treeo (the merge function might be needed less often).
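A minimal sketch of the boolean-tree approach, assuming a mask with exactly the same structure as the tree (True = static). `partition` and `combine` here are hypothetical helpers written for illustration, not Equinox's actual code. The usage shows how the same split lets jax.grad differentiate only the dynamic leaves, which also hints at why the output of a gradient transformation is harder to define: static positions come back as None.

```python
import jax
import jax.numpy as jnp


def partition(tree, static_mask):
    """Split `tree` into (static, dynamic) trees using a boolean mask.

    Hypothetical helper: `static_mask` has the same structure as `tree`, with
    True marking static leaves. Each output keeps the structure, with None
    filling the holes.
    """
    static = jax.tree_util.tree_map(lambda m, x: x if m else None, static_mask, tree)
    dynamic = jax.tree_util.tree_map(lambda m, x: None if m else x, static_mask, tree)
    return static, dynamic


def combine(static, dynamic):
    """Merge the two halves back, taking whichever side is not None."""
    return jax.tree_util.tree_map(
        lambda s, d: d if s is None else s,
        static,
        dynamic,
        is_leaf=lambda x: x is None,  # treat the None holes as leaves
    )


def loss(t):
    return t["n"] * jnp.sum(t["w"] ** 2)


tree = {"n": 2, "w": jnp.array([1.0, 2.0])}
static, dynamic = partition(tree, {"n": True, "w": False})

# Differentiate only with respect to the dynamic half; "n" stays a Python int.
grads = jax.grad(lambda d: loss(combine(static, d)))(dynamic)
```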

It would be good if this became universally available (jit, pmap, grad, etc.).

patrick-kidger commented 3 years ago

Thanks @cgarciae for the heads-up.

I'd suggest having a look at the filter_jit and filter_grad functions of Equinox. They take a filter_spec argument, which does precisely what this issue proposes, plus a bit more: they also accept filter functions as well as True/False literals, so that, e.g., you can specify that you'd like to trace every array.
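To show how filter functions and True/False literals could coexist in one spec, here is a hypothetical `resolve_filter_spec` helper that expands a spec into one bool per leaf. It is a sketch, not Equinox's implementation. Here True is assumed to mean "trace this leaf", matching the "trace every array" example, and `is_array` is an example filter written for this sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array(x):
    """Example filter function: True for JAX or NumPy arrays."""
    return isinstance(x, (jax.Array, np.ndarray))


def resolve_filter_spec(filter_spec, tree):
    """Expand a filter spec into one bool per leaf of `tree` (hypothetical).

    `filter_spec` is a pytree prefix of `tree` whose leaves are either
    True/False literals or filter functions applied to every leaf below them.
    """

    def expand(spec, subtree):
        if callable(spec):
            return jax.tree_util.tree_map(lambda leaf: bool(spec(leaf)), subtree)
        return jax.tree_util.tree_map(lambda _: bool(spec), subtree)

    return jax.tree_util.tree_map(expand, filter_spec, tree)


params = {"w": jnp.ones(3), "activation": "relu", "n_layers": 2}

# One filter function for the whole tree: "trace every array".
trace_mask = resolve_filter_spec(is_array, params)

# Literals and functions can be mixed at different positions in the tree.
mixed_mask = resolve_filter_spec(
    {"w": is_array, "activation": False, "n_layers": True}, params
)
```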

(I've been meaning to also add a filter_vmap, which would be able to handle a few edge cases that current vmap can't.)

Functionality-wise, I think this is a strict improvement over the current jit and grad.

In short, I agree completely that this is something I'd like to see in core JAX.

AdrienCorenflos commented 3 years ago

Just to entice the maintainers to have a look, please note that I would be more than happy to open a PR and implement the whole thing if they agree that it is a good idea!