adam-hartshorne closed this issue 1 year ago.
Oh, cool, great catch!
From what I can tell, the package converts the PyTorch computational graph into a JAX computational graph dynamically, as the code runs.
I think doing it on the fly is a great idea: all of the computation ends up being done by JAX, yet the user does not have to convert anything by hand.
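A minimal sketch of what op-by-op conversion involves, assuming a toy "recorded graph" format (a list of `(op_name, input_indices)` nodes) rather than anything the other torch2jax actually uses. The names `TORCH_TO_JAX` and `convert` are hypothetical. Every op needs its own hand-written JAX translation, and an op missing from the table cannot be converted at all.

```python
import jax
import jax.numpy as jnp

# Hypothetical hand-written table: each PyTorch "atom" needs its own JAX translation.
TORCH_TO_JAX = {
    "add": jnp.add,
    "matmul": jnp.matmul,
    "relu": jax.nn.relu,
    "tanh": jnp.tanh,
}


def convert(graph):
    """Turn a recorded list of (op_name, input_indices) nodes into a JAX function.

    The function's inputs occupy the first value slots; each node appends its
    output, and the last value is returned.
    """
    # Fail early if any op has no hand-written translation.
    for op_name, _ in graph:
        if op_name not in TORCH_TO_JAX:
            raise NotImplementedError(f"no JAX translation for {op_name!r}")

    def jax_fn(*inputs):
        values = list(inputs)
        for op_name, arg_ids in graph:
            args = [values[i] for i in arg_ids]
            values.append(TORCH_TO_JAX[op_name](*args))
        return values[-1]

    return jax_fn


# relu(x @ w + b): inputs are x (slot 0), w (slot 1), b (slot 2).
graph = [("matmul", (0, 1)), ("add", (3, 2)), ("relu", (4,))]
fn = jax.jit(convert(graph))
out = fn(jnp.array([[1.0, -2.0]]), jnp.eye(2), jnp.zeros(2))  # [[1., 0.]]

# An op without a translation is rejected at conversion time.
try:
    convert([("softplus", (0,))])
except NotImplementedError as err:
    missing = str(err)
```

Every node becomes a separate primitive in the traced JAX program, which is also why a very large PyTorch graph turns into a very large program for XLA to compile.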
However, it probably requires that the conversion for most atoms (e.g., activation functions, layers) be implemented by hand (a fair few are already implemented). Another problem is that if the PyTorch computational graph has a large number of nodes, JAX's global computational graph optimizer (when using JIT) might take a really long time to compile the resulting model (e.g., the DARTS network). Finally, the other torch2jax most likely does not support in-place operations.
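In-place operations are hard for a converter because JAX arrays are immutable. The sketch below uses NumPy as a stand-in for PyTorch's in-place semantics (an assumption to keep it free of a PyTorch dependency): a write through a view changes the original buffer, whereas JAX rejects item assignment and only offers the out-of-place `.at[...].set(...)` update, so a converter would have to rewrite every mutation and track every alias itself.

```python
import numpy as np
import jax.numpy as jnp

# NumPy (like PyTorch) allows in-place writes, and a view sees the change.
a = np.zeros(3)
view = a[:2]
view[0] = 1.0  # a[0] is now 1.0 as well, because view aliases a

# JAX arrays are immutable: item assignment raises a TypeError.
b = jnp.zeros(3)
try:
    b[0] = 1.0
    mutation_allowed = True
except TypeError:
    mutation_allowed = False

# The functional equivalent returns a new array and leaves b untouched.
c = b.at[0].set(1.0)
```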
Unlike the other package, this torch2jax wraps the entire PyTorch function as-is, so existing PyTorch code, especially custom or mutating (in-place) code, runs without modification.
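To illustrate the general idea of wrapping a function as-is instead of converting it, here is a sketch using JAX's generic `jax.pure_callback` mechanism. This is a swapped-in technique, not necessarily how this torch2jax is implemented, and `opaque_host_fn` is a hypothetical NumPy stand-in for arbitrary PyTorch code. Because the callback runs outside the JAX trace, its in-place updates never need a JAX translation; JAX only needs the output shape and dtype up front.

```python
import numpy as np
import jax
import jax.numpy as jnp


def opaque_host_fn(x):
    # Hypothetical stand-in for arbitrary PyTorch code. It mutates a buffer in
    # place, which is fine because JAX never traces this function.
    out = np.array(x, copy=True)
    out *= 2.0     # in-place update
    out[0] = -1.0  # item assignment
    return out


def wrapped(x):
    # JAX must be told the shape and dtype of the callback's result in advance.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(opaque_host_fn, result_shape, x)


f = jax.jit(wrapped)
y = f(jnp.arange(3.0))  # [-1., 2., 4.]
```

The trade-off is that JAX sees the callback as a black box: it cannot optimize inside it, and differentiating through it requires supplying the gradient yourself (e.g., with `jax.custom_vjp`).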
As the title suggests, I'd be interested to get your thoughts on this newly released alternative.
https://github.com/samuela/torch2jax/tree/main