Closed arogozhnikov closed 5 years ago
Thanks for the question! Clearly that's an important one for users to ask themselves as they consider what libraries to use to do their work.
However, instead of making comparisons to other libraries, we'd rather focus on explaining what JAX can do. At its core, JAX is about transforming functions, including via forward- and reverse-mode autodiff, `jit` compilation with end-to-end XLA optimization to multiple backends, and `vmap` for things like batching and per-example gradients. JAX makes all those transformations arbitrarily composable, which is how we build compound transformations like fast Hessian calculations (which might use one level of forward-mode autodiff, one level of reverse-mode autodiff, two levels of `vmap`, and a level of `jit`):
```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
```
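To make the composition concrete, here is a minimal sketch of that Hessian transformation applied to a simple function (the function `f` is just an illustrative example, not from the original post):

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # forward-over-reverse autodiff: jacrev computes the gradient,
    # jacfwd differentiates it again, and jit compiles the whole thing
    return jit(jacfwd(jacrev(fun)))

# illustrative scalar function: f(x) = sum(x**3)
f = lambda x: jnp.sum(x ** 3)

H = hessian(f)(jnp.array([1.0, 2.0]))
# Hessian of sum(x**3) is diag(6 * x), so H == [[6., 0.], [0., 12.]]
```

Because each transformation returns an ordinary function, the same `hessian` helper works unchanged on any differentiable function of the right shape.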
So while JAX provides a NumPy API (like many libraries do!) and, because we use `jit` in library functions, it can execute code on accelerators for individual calls, that's really just the starting point. JAX is about compositional transformations.
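As another sketch of composability, the per-example gradients mentioned above can be written by composing `grad`, `vmap`, and `jit` (the linear-model loss here is a hypothetical example, not from the original post):

```python
import jax.numpy as jnp
from jax import grad, vmap, jit

# hypothetical loss for one example: squared error of a linear model
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates wrt w; vmap maps over the batch axes of x and y
# (in_axes=(None, 0, 0) means w is shared, x and y are batched);
# jit compiles the composed function
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
ys = jnp.array([0.0, 0.0])

g = per_example_grads(w, xs, ys)
# grad wrt w of (w.x - y)**2 is 2*(w.x - y)*x, so
# g == [[2., 0.], [0., 4.]] -- one gradient per example
```

Without `vmap`, getting per-example gradients typically means a Python loop or manual batching tricks; here it falls out of composing the same transformations used everywhere else.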
Hopefully the README provides more detail on what JAX can do. We'd like to keep improving the README, so if you have questions about it or suggestions on how to make it more clear, we'd love to hear them! Having a "showcase" is a great idea, and we're working on more examples. But we'll leave comparisons to specific libraries up to the users.
CuPy is a quite stable and efficient "NumPy for GPU" (which has none of the restrictions mentioned in the README), and Chainer on top of CuPy provides the necessary auto-diff and primitives for deep learning. There are also other alternatives.
It would be nice to have showcases of when JAX is expected to be beneficial compared to already existing tools.
Thanks!