AdrienCorenflos opened this issue 3 years ago.
To add another motivation, Equinox (@patrick-kidger) currently relies on a `partition` function to split the tree into a purely static tree and a purely "dynamic" tree. If this feature becomes available, it would only need to return a boolean tree, which would simplify the API. It might also help simplify Treeo (reducing the need for its `merge` function).
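To make the "boolean tree" idea concrete, here is a minimal sketch in plain JAX. The helpers `partition` and `combine` below are hypothetical illustrations written for this comment; they follow the same spirit as Equinox's functions but are not its API. The mask is a pytree of booleans with the same structure as the data, and `None` is used as a placeholder in whichever half does not own a leaf.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def partition(tree, mask):
    """Split `tree` into (dynamic, static) halves using a boolean `mask`.

    `mask` has the same structure as `tree`; True marks a leaf as dynamic
    (traced). Each half holds None wherever the other half owns the leaf.
    """
    dynamic = tree_util.tree_map(lambda m, x: x if m else None, mask, tree)
    static = tree_util.tree_map(lambda m, x: None if m else x, mask, tree)
    return dynamic, static


def combine(dynamic, static):
    """Inverse of `partition`: take each leaf from whichever half is not None."""
    return tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,  # keep None placeholders as leaves
    )


params = {"w": jnp.ones(3), "activation": "relu"}
mask = {"w": True, "activation": False}  # the boolean tree the proposal asks for
dynamic, static = partition(params, mask)
restored = combine(dynamic, static)
```

With a boolean tree as the only user-facing input, the splitting and merging can stay internal to the transformation.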
It would be good if this became available across all transformations (`jit`, `pmap`, `grad`, etc.).
Thanks @cgarciae for the heads-up.
I'd suggest having a look at the `filter_jit` and `filter_grad` functions of Equinox. They take a `filter_spec` argument, which does precisely what this issue proposes (plus a bit more: they also accept filter functions, as well as `True`/`False` literals, so that, e.g., you can specify that you'd like to trace every array).
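As a rough illustration of how a spec mixing `True`/`False` literals and filter functions could be resolved into a per-leaf boolean tree, consider the sketch below. `resolve_filter_spec` and `is_array` are hypothetical names chosen for this example (not Equinox's implementation), and the prefix-broadcasting rule is an assumption modelled on how `vmap`'s `in_axes` treats a prefix of the arguments.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util


def is_array(x):
    """Example filter function: select JAX or NumPy arrays."""
    return isinstance(x, (jax.Array, np.ndarray))


def resolve_filter_spec(filter_spec, tree):
    """Expand a filter spec into a boolean tree with the structure of `tree`.

    Like vmap's `in_axes`, `filter_spec` may be a prefix of `tree`. Each of
    its leaves is either a bool literal (applied to the whole subtree) or a
    callable (applied to every leaf of the subtree).
    """
    def expand(spec, subtree):
        return tree_util.tree_map(
            lambda leaf: bool(spec(leaf)) if callable(spec) else bool(spec),
            subtree,
        )

    return tree_util.tree_map(expand, filter_spec, tree)


params = {"w": jnp.ones(3), "b": jnp.zeros(()), "activation": "relu"}
model = {"layer": params, "scale": 2.0}

trace_everything = resolve_filter_spec(True, model)
arrays_only = resolve_filter_spec(is_array, model)
mixed = resolve_filter_spec({"layer": is_array, "scale": True}, model)
```

Here `arrays_only` marks `w` and `b` as traced but leaves the string and the Python float alone, while `mixed` additionally opts `scale` in.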
(I've been meaning to also add a `filter_vmap`, which would handle a few edge cases that the current `vmap` can't.)
Functionality-wise, I think this is a strict improvement over the current `jit` and `grad`.
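For the gradient side, a `filter_grad`-style transformation can be sketched by differentiating only the selected half of the tree and closing over the rest as constants. `grad_wrt_mask` is a hypothetical helper for illustration; in this example, calling `jax.grad(loss)(params)` directly would raise a `TypeError`, because the string leaf is not a valid JAX type.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def grad_wrt_mask(fun, mask):
    """Differentiate `fun(tree)` only w.r.t. the leaves where `mask` is True.

    Unselected leaves are closed over as constants, so they may be
    non-array values (strings, ints, ...) that plain jax.grad would reject.
    The returned gradient tree has None at the unselected positions.
    """
    def grad_fun(tree):
        dynamic = tree_util.tree_map(lambda m, x: x if m else None, mask, tree)
        static = tree_util.tree_map(lambda m, x: None if m else x, mask, tree)

        def inner(dyn):
            # Re-assemble the full tree from the traced and constant halves.
            full = tree_util.tree_map(
                lambda d, s: s if d is None else d,
                dyn,
                static,
                is_leaf=lambda x: x is None,
            )
            return fun(full)

        return jax.grad(inner)(dynamic)

    return grad_fun


def loss(params):
    act = jnp.tanh if params["activation"] == "tanh" else jax.nn.relu
    return jnp.sum(act(params["w"]) ** 2)


params = {"w": jnp.array([0.5, -1.0]), "activation": "tanh"}
mask = {"w": True, "activation": False}
grads = grad_wrt_mask(loss, mask)(params)
```

Returning `None` for the unselected leaves keeps the gradient tree aligned with the input structure, which is one (assumed) answer to the "harder to define output" question raised for gradient transformations.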
In short, I agree completely that this is something I'd like to see in core JAX.
Just to entice the maintainers to have a look, please note that I would be more than happy to open a PR and implement the whole thing if they agree that it is a good idea!
Hi,
I would like to propose an improvement to the `jit` API to bring it closer to the `vmap` and `pmap` syntax.
At the moment, functions that take a pytree as an argument must mark the whole pytree as either static or non-static, with no way to distinguish between the leaves themselves. This makes it hard to build user-friendly interfaces (see, e.g., https://github.com/blackjax-devs/blackjax/wiki/Meeting-minutes), and it is a problem visibly encountered by many users, who have found different ways of circumventing it (e.g., https://github.com/cgarciae/treeo by @cgarciae).
In order to make for a flexible, yet friendly, API, I believe users should be able to pass in an argument similar in style to the `vmap` `in_axes` argument. For example:
This could be used as follows (currently supported):
as well as (currently not supported):
hence providing an API similar to that of `vmap`, both in terms of capabilities and design. Ideally, this format should also be supported for gradient transformations, but that would probably require more thought, as the output would then be harder to define.

WDYT?
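One way to prototype the proposed per-leaf static specification on top of today's `jax.jit` is sketched below. `jit_with_static_spec` and `_Static` are hypothetical names, and the approach (split the arguments by the spec, pass the static half through `static_argnums` inside a hashable wrapper) is an assumption about how this could be emulated, not how JAX would implement it. The spec has one entry per positional argument and, like `in_axes`, may be a prefix of the arguments.

```python
import functools

import jax
import jax.numpy as jnp
from jax import tree_util


class _Static:
    """Hashable wrapper so the static half can go through `static_argnums`.

    Assumes every static leaf is hashable (the same requirement jit already
    places on static arguments).
    """

    def __init__(self, tree):
        leaves, self.treedef = tree_util.tree_flatten(tree)
        self.leaves = tuple(leaves)

    def __hash__(self):
        return hash((self.treedef, self.leaves))

    def __eq__(self, other):
        return (
            isinstance(other, _Static)
            and self.treedef == other.treedef
            and self.leaves == other.leaves
        )

    def value(self):
        return tree_util.tree_unflatten(self.treedef, self.leaves)


def jit_with_static_spec(fun, static_spec):
    """Hypothetical sketch: jit with a per-leaf static spec (True = static)."""

    @functools.partial(jax.jit, static_argnums=(1,))
    def jitted(dynamic_args, static_args):
        # Merge traced leaves with the constant (static) leaves.
        args = tree_util.tree_map(
            lambda d, s: s if d is None else d,
            dynamic_args,
            static_args.value(),
            is_leaf=lambda x: x is None,
        )
        return fun(*args)

    def wrapper(*args):
        # Broadcast the (prefix) spec down to one bool per leaf.
        mask = tree_util.tree_map(
            lambda s, sub: tree_util.tree_map(lambda _: s, sub), static_spec, args
        )
        dynamic = tree_util.tree_map(lambda m, x: None if m else x, mask, args)
        static = tree_util.tree_map(lambda m, x: x if m else None, mask, args)
        return jitted(dynamic, _Static(static))

    return wrapper


def apply(params, x):
    act = jnp.tanh if params["activation"] == "tanh" else jax.nn.relu
    return act(params["w"] * x)


params = {"w": jnp.array([1.0, 2.0]), "activation": "tanh"}
# Per-leaf spec for `params`, and a single literal for the whole of `x`.
fast_apply = jit_with_static_spec(apply, ({"w": False, "activation": True}, False))
out = fast_apply(params, jnp.array(0.5))
```

Changing a static leaf (e.g. `"activation": "relu"`) changes the hash of the static half and triggers a retrace, mirroring the existing `static_argnums` behaviour.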
cc @rlouf @junpenglao