Closed by adam-hartshorne 9 months ago
Yeah, I was not previously aware of https://github.com/rdyro/torch2jax. Wish I had seen it earlier! Glancing through their documentation it looks like they are using dispatch to call into PyTorch execution from JAX code. OTOH, samuela/torch2jax traces PyTorch execution, resulting in a JAX-native computation graph. As a result, I do not believe rdyro/torch2jax can support jit, vmap, grad, and so forth.
Looking at the readme,
jax.hessian(f) will not work since torch2jax uses forward differentiation, but the same functionality can be achieved using jax.jacobian(jax.jacobian(f))
Forward mode autodiff is slower than reverse mode autodiff as long as you have more inputs than outputs, as is often the case in machine learning (millions of parameters, scalar loss). samuela/torch2jax supports forward mode, reverse mode, and any combination of the two by virtue of JAX supporting rich autodiff options.
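To make the forward vs. reverse point concrete, here is a minimal sketch in plain JAX. It does not use either torch2jax package; `loss` is a hypothetical stand-in for any JAX-native function with many inputs and one scalar output.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical scalar loss: many inputs (the entries of w), one output.
    return jnp.sum(jnp.sin(w) ** 2)

w = jnp.arange(5.0)

# Reverse mode: a single backward pass yields the whole gradient.
g_rev = jax.grad(loss)(w)

# Forward mode: conceptually one JVP per input coordinate for the same gradient.
g_fwd = jax.jacfwd(loss)(w)

# jax.hessian composes forward mode over reverse mode (jacfwd of jacrev).
h = jax.hessian(loss)(w)

# Nesting jax.jacobian (an alias for jacrev) produces the same matrix.
h_nested = jax.jacobian(jax.jacobian(loss))(w)
```

With a JAX-native graph, both modes and their compositions are available, so `jax.hessian` and the nested-jacobian form agree here.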
input shapes are fixed for one wrapped function and cannot change, use torch2jax_with_vjp/torch2jax again if you need to alter the input shapes
samuela/torch2jax does not have this limitation.
in line with JAX philosophy, PyTorch functions must be non-mutable, torch.func has a good description of how to convert e.g., PyTorch models, to non-mutable formulation
samuela/torch2jax does not have this limitation.
in the PyTorch function all arguments must be tensors, all outputs must be tensors
samuela/torch2jax does not have this limitation.
all arguments must be on the same device and of the same datatype, either float32 or float64
samuela/torch2jax does not have this limitation.
input/output shape representations (e.g. the output_shapes= kw argument), which allow flexibility in input and output structure, must be wrapped in torch.Size or jax.ShapeDtypeStruct
the current implementation does not support batching, that's on the roadmap
samuela/torch2jax supports batching and vmap out of the box.
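As a rough illustration of what batching via vmap means for a JAX-native function, here is a plain JAX sketch. `predict` is a hypothetical per-example function, not something taken from either library.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical per-example function: one input vector in, one scalar out.
    return jnp.tanh(jnp.dot(w, x))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples

# Map over the leading axis of xs while sharing w across the batch.
batched = jax.vmap(predict, in_axes=(None, 0))
ys = batched(w, xs)

# The same results computed one example at a time, for comparison.
ys_loop = jnp.stack([predict(w, x) for x in xs])
```

The function itself is written for a single example; `jax.vmap` adds the batch dimension without any change to `predict`.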
the current implementation does not define the VJP rule, in current design, this has to be done in Python
samuela/torch2jax supports VJP and JVP out of the box.
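For reference, this is roughly what JVP and VJP look like on a JAX-native function. The sketch uses only core JAX, and `f` is a hypothetical function chosen for illustration, not part of either torch2jax API.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function from R^3 to R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# JVP (forward mode): push a tangent vector v through f, giving J @ v.
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# VJP (reverse mode): pull a cotangent vector u back through f, giving u @ J.
u = jnp.array([1.0, 1.0])
y2, vjp_fn = jax.vjp(f, x)
(uJ,) = vjp_fn(u)

# Full Jacobian, used here only to check the two products above.
J = jax.jacobian(f)(x)
```

Neither rule has to be written by hand here; JAX derives both from the traced computation.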
Thanks for the swift response.
I presume you aren't aware of this,
https://github.com/rdyro/torch2jax
How does it compare to your work?