JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
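Here is a small sketch of what differentiating native Python looks like. The function `f` is a made-up example, not something from the text. It uses a `for` loop and an `if` on the input value. `jax.grad` (reverse mode) is nested to get a third derivative, and `jax.jvp` computes the same first derivative in forward mode.

```python
import jax

def f(x):
    # Ordinary Python control flow: a loop and a branch on the input's value.
    y = x
    for _ in range(3):
        y = y * x              # after the loop, y == x**4
    if x > 0:
        return y
    return -y

df = jax.grad(f)                        # reverse mode: 4 * x**3 for x > 0
d3f = jax.grad(jax.grad(jax.grad(f)))   # third derivative: 24 * x for x > 0

print(df(1.5))    # 13.5
print(d3f(1.5))   # 36.0

# Forward mode: jax.jvp pushes a tangent of 1.0 through f at the same point.
value, slope = jax.jvp(f, (1.5,), (1.0,))
print(value, slope)   # 5.0625 13.5
```

The input 1.5 was chosen so that every intermediate value is exactly representable in float32.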
What’s new is that JAX uses XLA to compile and run your NumPy code on accelerators, like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
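As a rough illustration of composing compilation with differentiation, the sketch below wraps a gradient function in `jax.jit`. The `loss` function and the input shapes are illustrative assumptions. For this data the analytic gradient (2/N)·Xᵀ(Xw − y) evaluates to −2 in every component.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model (an illustrative function, not from the text).
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# grad builds the derivative function; jit compiles that function with XLA.
grad_loss = jax.jit(jax.grad(loss))   # gradient w.r.t. the first argument, w

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
print(grad_loss(w, x, y))   # [-2. -2. -2.]
```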
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
Multiplying Matrices
We’ll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see Common Gotchas in JAX.
key = random.key(0)
x = random.normal(key, (10,))
print(x)
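The main practical difference from NumPy is that JAX random functions take an explicit key and have no hidden global state. The sketch below shows this under that model: the same key reproduces the same samples, and `random.split` produces a new key for fresh samples. The variable name `demo_key` is arbitrary.

```python
from jax import random
import jax.numpy as jnp

demo_key = random.key(0)
# Random functions are pure: reusing the same key reproduces the same numbers.
a = random.normal(demo_key, (3,))
b = random.normal(demo_key, (3,))

# For fresh numbers, split the key and pass the new subkey to the sampler.
demo_key, subkey = random.split(demo_key)
c = random.normal(subkey, (3,))

print(jnp.array_equal(a, b))   # True
print(jnp.array_equal(a, c))   # almost certainly False
```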
Let’s dive right in and multiply two big matrices.
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU, if one is available
13.5 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
We added that block_until_ready because JAX uses asynchronous execution by default (see Asynchronous dispatch).
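Outside IPython, where `%timeit` is not available, one possible approach is the standard `timeit` module. The sketch assumes a 1000×1000 matrix to keep the run short. It calls `block_until_ready()` inside the timed function so that the measurement covers the computation rather than just the asynchronous dispatch. The helper name `matmul` is hypothetical.

```python
import timeit
import jax.numpy as jnp
from jax import random

m = random.normal(random.key(0), (1000, 1000), dtype=jnp.float32)

def matmul():
    # block_until_ready waits for the result; without it, timing would mostly
    # measure how long it takes to dispatch the work, not to finish it.
    return jnp.dot(m, m.T).block_until_ready()

matmul()  # warm-up call so one-time compilation is not timed
per_call = min(timeit.repeat(matmul, number=5, repeat=3)) / 5
print(f"{per_call * 1e3:.2f} ms per call")
```

The actual figure depends entirely on the hardware.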
JAX NumPy functions work on regular NumPy arrays.
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
80 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
That’s slower because the data has to be transferred from host memory to the GPU on every call. You can ensure that an NDArray is backed by device memory using device_put().
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
15.8 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
The output of device_put() still acts like an NDArray, but it only copies values back to the CPU when they’re needed for printing, plotting, saving to disk, branching, etc. The behavior of device_put() is equivalent to the function jit(lambda x: x), but it’s faster.
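A minimal sketch of the round trip between host and device follows. `device_put` copies a NumPy array to the default device, or to an explicitly chosen one, and returns a `jax.Array`. `np.asarray` copies the values back to the host. The exact device printed depends on the machine.

```python
import numpy as np
import jax
from jax import device_put

x_host = np.arange(6, dtype=np.float32).reshape(2, 3)

x_dev = device_put(x_host)                     # copy to the default device
print(isinstance(x_dev, jax.Array))            # True
print(x_dev.devices())                         # e.g. {CpuDevice(id=0)}

# A target device can also be given explicitly.
x_dev0 = device_put(x_host, jax.devices()[0])

# np.asarray copies the values back to host memory as a NumPy array.
x_back = np.asarray(x_dev)
print(np.array_equal(x_back, x_host))          # True
```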
If you have a GPU (or TPU!), these calls run on the accelerator and can be much faster than on the CPU. See Is JAX faster than NumPy? for a more detailed comparison of the performance characteristics of NumPy and JAX.
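To check which hardware these calls will actually use, you can query JAX directly. A short sketch:

```python
import jax

backend = jax.default_backend()   # platform name, e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()           # devices available on that default backend
print(backend, devices)
```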
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
- jit(), for speeding up your code
- grad(), for taking derivatives
- vmap(), for automatic vectorization or batching.
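As a preview of the three transformations listed above, the sketch applies each one to a hypothetical single-unit model `predict`, then composes them to get compiled per-example gradients. The model, weights, and batch are illustrative assumptions. `in_axes=(None, 0)` tells `vmap` to share `w` and map over the rows of `xs`.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical model: a single tanh unit applied to one input vector x.
    return jnp.tanh(jnp.dot(w, x))

w = jnp.array([0.5, -0.25, 1.0])
xs = jnp.ones((8, 3))                      # a batch of 8 input vectors

fast_predict = jax.jit(predict)            # jit: compile with XLA
dpredict_dw = jax.grad(predict)            # grad: derivative with respect to w
batched_predict = jax.vmap(predict, in_axes=(None, 0))   # vmap: map over rows of xs

# The transformations compose: per-example gradients, compiled.
per_example_grads = jax.jit(jax.vmap(jax.grad(predict), in_axes=(None, 0)))(w, xs)

print(batched_predict(w, xs).shape)        # (8,)
print(per_example_grads.shape)             # (8, 3)
```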