samuela / torch2jax

Run PyTorch in JAX. 🤝

Alternative Torch2Jax #2

Closed · adam-hartshorne closed this issue 9 months ago

adam-hartshorne commented 9 months ago

I presume you aren't aware of this,

https://github.com/rdyro/torch2jax

How does it compare to your work?

samuela commented 9 months ago

Yeah, I was not previously aware of https://github.com/rdyro/torch2jax. Wish I had seen it earlier! Glancing through their documentation, it looks like they dispatch into PyTorch execution at runtime from JAX code. OTOH, samuela/torch2jax traces PyTorch execution, producing a JAX-native computation graph. As a result, I do not believe rdyro/torch2jax can support jit, vmap, grad, and so forth.
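For example, something like this should work (untested sketch; it assumes the `t2j` entry point from this repo's README):

```python
import jax
import jax.numpy as jnp
import torch
from torch2jax import t2j  # assumed entry point, per this repo's README

# A plain PyTorch function...
def torch_fn(x):
    return torch.sin(x).sum()

# ...traced into a JAX-native function, so standard JAX transformations apply.
jax_fn = t2j(torch_fn)
x = jnp.linspace(0.0, 1.0, 5)
print(jax.jit(jax_fn)(x))
```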

Looking at their readme:

> jax.hessian(f) will not work since torch2jax uses forward differentiation, but the same functionality can be achieved using jax.jacobian(jax.jacobian(f))

Forward-mode autodiff is slower than reverse-mode autodiff whenever you have more inputs than outputs, as is often the case in machine learning (millions of parameters, scalar loss). samuela/torch2jax supports forward mode, reverse mode, and any combination of the two by virtue of JAX's rich autodiff support.
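To make the mode distinction concrete, here's a pure-JAX sketch (nothing library-specific) of why reverse mode wins for many-inputs/one-output, and how the modes can be mixed:

```python
import jax
import jax.numpy as jnp

# Typical ML shape: many inputs, scalar output.
def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.ones(1000)

g_rev = jax.grad(loss)(w)     # reverse mode: one backward pass covers all 1000 inputs
g_fwd = jax.jacfwd(loss)(w)   # forward mode: cost grows with the 1000 inputs

# Mixing modes, e.g. a Hessian as forward-over-reverse:
H = jax.jacfwd(jax.jacrev(loss))(w)
```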

> input shapes are fixed for one wrapped function and cannot change, use torch2jax_with_vjp/torch2jax again if you need to alter the input shapes

samuela/torch2jax does not have this limitation.
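i.e., roughly (untested sketch, again assuming the `t2j` entry point), the same wrapped function can be reused across input shapes:

```python
import jax.numpy as jnp
import torch
from torch2jax import t2j  # assumed entry point, per this repo's README

f = t2j(lambda x: torch.relu(x).sum())

f(jnp.ones(3))        # one input shape...
f(jnp.ones((4, 5)))   # ...and another, without re-wrapping the function
```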

> in line with JAX philosophy, PyTorch functions must be non-mutable, torch.func has a good description of how to convert e.g., PyTorch models, to non-mutable formulation

samuela/torch2jax does not have this limitation.

> in the Pytorch function all arguments must be tensors, all outputs must be tensors

samuela/torch2jax does not have this limitation.

> all arguments must be on the same device and of the same datatype, either float32 or float64

samuela/torch2jax does not have this limitation.

> an input/output shape (e.g. output_shapes= kw argument) representations (for flexibility in input and output structure) must be wrapped in torch.Size or jax.ShapeDtypeStruct

> the current implementation does not support batching, that's on the roadmap

samuela/torch2jax supports batching and vmap out of the box.
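For example (untested sketch, assuming `t2j` as above):

```python
import jax
import jax.numpy as jnp
import torch
from torch2jax import t2j  # assumed entry point, per this repo's README

f = t2j(lambda x: torch.sin(x).sum())  # maps a vector to a scalar

batch = jnp.ones((8, 3))
per_example = jax.vmap(f)(batch)  # shape (8,): f applied to each row of the batch
```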

> the current implementation does not define the VJP rule, in current design, this has to be done in Python

samuela/torch2jax supports VJP and JVP out of the box.
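e.g. (untested sketch, assuming `t2j` as above):

```python
import jax
import jax.numpy as jnp
import torch
from torch2jax import t2j  # assumed entry point, per this repo's README

f = t2j(lambda x: torch.tanh(x).sum())
x = jnp.ones(4)

# Reverse mode: the VJP comes from JAX, no hand-written rule needed.
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y))

# Forward mode: JVP works as well.
y2, y_dot = jax.jvp(f, (x,), (jnp.ones(4),))
```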

adam-hartshorne commented 9 months ago

Thanks for the swift response.