rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Any thoughts on this torch2jax alternative? #8

Closed: adam-hartshorne closed this issue 1 year ago

adam-hartshorne commented 1 year ago

As the title suggests, I'd be interested to get your thoughts on this newly released alternative.

https://github.com/samuela/torch2jax/tree/main

rdyro commented 1 year ago

Oh, cool, great catch!

From what I can tell, the package attempts to dynamically convert the torch computational graph into a JAX computational graph.

I think doing it on the fly is a great idea since it moves all the computation into JAX while requiring no manual conversion work from the user.

However, it probably requires that the conversion for most atoms (e.g., activation functions, layers) be implemented by hand (a fair few are already implemented). Another problem is that if the PyTorch computation graph has a large number of nodes, JAX's global computational graph optimizer (when using JIT) might take a very long time to compile the resulting model (e.g., the DARTS network). Finally, the other torch2jax most likely does not support in-place operations.
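
For reference, here's a minimal sketch of how I understand the graph-conversion approach is used; `t2j` is the entry point samuela/torch2jax advertises, but the exact signature and the set of supported ops here are assumptions on my part:

```python
import jax
import jax.numpy as jnp
import torch
from torch2jax import t2j  # samuela/torch2jax (not this package)

def torch_fn(x):
    # built only from ops the converter has hand-implemented translations for
    return torch.nn.functional.relu(x).sum()

jax_fn = t2j(torch_fn)      # traces the torch graph into JAX operations
x = jnp.linspace(-1.0, 1.0, 8)
print(jax.jit(jax_fn)(x))   # runs entirely in JAX, so jit/grad/vmap compose
```

The upside is clear: once converted, the function is ordinary JAX code. The downside is the one above: every torch node becomes JAX operations, so a very large graph can make `jax.jit` compilation slow.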

Unlike the other package, this torch2jax wraps the entire PyTorch function as-is, so existing PyTorch code, especially custom or mutating code, can run without modification.
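
For comparison, a minimal sketch of the wrapping approach, based on this repo's `torch2jax(fn, *example_args)` entry point; the in-place example itself is mine, and the shape-inference behavior is my reading of the README:

```python
import torch
import jax
import jax.numpy as jnp
from torch2jax import torch2jax  # this package

def torch_fn(x):
    # mutating PyTorch code that a graph converter would likely choke on
    x = x.clone()
    x[x < 0.0] *= 2.0  # in-place masked update
    return x.sum()

example = torch.linspace(-1.0, 1.0, 8)
jax_fn = torch2jax(torch_fn, example)  # output shapes inferred from example args
print(jax.jit(jax_fn)(jnp.linspace(-1.0, 1.0, 8)))  # PyTorch runs under the hood
```

For gradients, the repo also documents a `torch2jax_with_vjp` variant that defines the reverse-mode rule by calling back into PyTorch autograd, which is what the "automatically defining gradients" line in the description refers to.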