greydanus closed this 4 years ago
Thanks.
My main concern is wrapping external libraries that are already implemented on other computing backends. We use a set of such operators for particle-mesh simulation and FFTs.
With custom primitives, we also need to tell the autodiff system which tensor operations (addition, inner products, etc.) to use on the operands. How is this currently done in JAX?
I believe Autograd did this with a vector space object that knows how to flatten any operand into a NumPy array, after which NumPy functions are used for inner products etc. This is not always desirable: e.g. if the data is partitioned across several MPI ranks, gathering it onto a single rank may not even fit in memory. We weren't able to use Autograd for this reason.
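For the external-library case, a minimal sketch in current JAX (not something proposed in this thread) wraps an opaque routine with `jax.pure_callback` and attaches a hand-written VJP with `jax.custom_vjp`. The two NumPy `cumsum` functions are hypothetical stand-ins for an external operator and its adjoint. JAX hands the backward rule a cotangent with the output's shape, so nothing here forces operands through a flattened vector-space representation; whether this suits MPI-partitioned data depends on how the callback itself handles distribution, which this sketch does not address.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for an external (non-JAX) library routine.
def external_cumsum(x):
    x = np.asarray(x)
    return np.ascontiguousarray(np.cumsum(x), dtype=x.dtype)

# Hypothetical external adjoint: cumsum is linear, and its transpose
# is a cumulative sum taken from the right.
def external_cumsum_adjoint(g):
    g = np.asarray(g)
    return np.ascontiguousarray(np.cumsum(g[::-1])[::-1], dtype=g.dtype)

def _call(fn, x):
    # Tell JAX the output shape/dtype; the callback itself is opaque to JAX.
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_vjp
def cumsum_op(x):
    return _call(external_cumsum, x)

def cumsum_fwd(x):
    # The operator is linear, so no residuals are needed.
    return cumsum_op(x), None

def cumsum_bwd(_, g):
    # g has the same shape as the output; return one cotangent per argument.
    return (_call(external_cumsum_adjoint, g),)

cumsum_op.defvjp(cumsum_fwd, cumsum_bwd)
```

For example, `jax.grad(lambda v: cumsum_op(v).sum())(jnp.array([1.0, 2.0, 3.0]))` evaluates the gradient entirely through the external adjoint.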
Another concern is whether these custom primitives support higher-order differentiation. If the VJP function itself must be an external routine (rather than a composition of primitives), then higher-order differentiation and automatic JVPs are presumably both broken. Is this a supported case?
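On higher-order derivatives: in current JAX, one way to get them is `jax.custom_jvp` with a rule written in JAX operations, because JAX can then differentiate the rule itself. A sketch using the familiar `log1pexp` example (an illustration, not code from this thread), where both `jax.jvp` and `jax.grad(jax.grad(...))` work. If the rule instead called an opaque external routine (e.g. through `jax.pure_callback`), JAX could not differentiate through it, so a second derivative would need its own custom rule.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # The rule uses only JAX ops, so JAX can differentiate it again.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

first = jax.grad(log1pexp)(0.0)               # sigmoid(0)
second = jax.grad(jax.grad(log1pexp))(0.0)    # sigmoid'(0)
_, tangent_out = jax.jvp(log1pexp, (0.0,), (1.0,))
```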
Not sure if this is the right place, but how can I define custom vmap rules for primitives? In my case, I am calling an external function that already supports batches, and I want to vmap the code surrounding the call to this external primitive.
@jonasrauber that question is from a while ago, but the short answer is that `custom_transforms` (as in `from jax import custom_transforms`) is for doing this. To be improved and documented...
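`custom_transforms` has since been superseded. A sketch of the same idea with `jax.custom_batching.custom_vmap`, assuming the external routine natively accepts a leading batch axis; `external_batched_norms` is a hypothetical NumPy stand-in called through `jax.pure_callback`. Under `vmap`, the custom rule hands the whole batch to the external routine in a single call, so JAX never needs a batching rule for the callback itself.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.custom_batching import custom_vmap

# Hypothetical external routine that already handles batches: (B, n) -> (B,).
def external_batched_norms(xs):
    xs = np.asarray(xs)
    return np.linalg.norm(xs, axis=-1).astype(xs.dtype)

def _call_external(xs):
    out = jax.ShapeDtypeStruct(xs.shape[:-1], xs.dtype)
    return jax.pure_callback(external_batched_norms, out, xs)

@custom_vmap
def norm(x):
    # Unbatched call: add a batch axis of size 1, then strip it.
    return _call_external(x[None])[0]

@norm.def_vmap
def norm_vmap_rule(axis_size, in_batched, x):
    x_batched, = in_batched
    if not x_batched:
        # Materialize the batch axis if the argument was not batched.
        x = jnp.broadcast_to(x, (axis_size,) + x.shape)
    # One external call for the whole batch; the output is batched.
    return _call_external(x), True
```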
I just sketched out a custom VJP API last night: https://gist.github.com/mattjj/2ba580930472e8e04c1759737268af92
The example there is trivial, and there's a bit more bookkeeping to be done to handle general code. But our initial thinking is that we can have a `defvjp` and a `defvjp_all`, to be used with `@custom_transforms`, where the former lets you specify a VJP function for each positional argument and the latter lets you specify a VJP for all arguments at once. (Maybe we can also provide a `defvjp_all_staged` if you want to compute some reduced residual information on the forward pass, rather than saving all the argument values.)
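The `defvjp`/`defvjp_all`/`defvjp_all_staged` names above belong to that early design. In current JAX, `jax.custom_vjp` covers the all-at-once and staged ideas: the forward function returns whatever residuals the backward pass needs, and the backward function returns one cotangent per positional argument. A minimal sketch with an illustrative function `f` that saves its two partial derivatives as residuals instead of the raw arguments:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Staged residuals: keep df/dx and df/dy, not x and y themselves.
    return f(x, y), (jnp.cos(x) * y, jnp.sin(x))

def f_bwd(res, g):
    dfdx, dfdy = res
    # All cotangents at once, one per positional argument.
    return (dfdx * g, dfdy * g)

f.defvjp(f_fwd, f_bwd)

gx, gy = jax.grad(f, argnums=(0, 1))(0.5, 2.0)
```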
The funny bookkeeping in that gist is due to the fact that in JAX we usually don't specify VJPs (reverse-mode rules) directly, and instead only specify forward-mode rules; JAX generates reverse-mode autodiff through a composition of forward-mode, partial evaluation, and transposition transformations. But if you want to specify a VJP rule directly, that gist shows a trick to do it.
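To make the forward-mode-plus-transposition point concrete, a small sketch using `jax.linearize` and `jax.linear_transpose` (current JAX functions, not from the gist): linearizing gives the JVP as a linear map at a point, and transposing that map yields the same cotangents that `jax.vjp` returns. The function `f` is illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2.0 * jnp.sin(x)

x = jnp.array([0.0, 1.0, 2.0])

# Forward mode + partial evaluation: primal output and the linear JVP map at x.
y, f_jvp = jax.linearize(f, x)

# Transposition: turn the linear JVP map into a VJP.
f_vjp_from_transpose = jax.linear_transpose(f_jvp, x)

ct = jnp.ones_like(y)
(grad_via_transpose,) = f_vjp_from_transpose(ct)
(grad_via_vjp,) = jax.vjp(f, x)[1](ct)
```

Both results equal `2 * cos(x)`, which is the relationship JAX relies on when it derives reverse mode from forward-mode rules.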
That was all a work-in-progress. We've got something better now!
Whoops, I didn't mean to close this in #818!
JAX supports custom primitives and VJPs, just like Autograd did. Possible improvements: 1) add this to the documentation, 2) add a minimal example to the examples section, 3) add a wrapper function if appropriate?