Description
I use autograd to calculate partial derivatives of functions of two variables (x, y). Since autograd is no longer actively developed, I'm trying to get the same results using jax.
These functions have the form:
$$\nabla^4 w = \cfrac{\partial^4 w}{\partial x^4} + 2\cfrac{\partial^4 w}{\partial x^2\partial y^2} + \cfrac{\partial^4 w}{\partial y^4}$$
where $w = w(x, y)$.
I use similar functions obtained by automatic differentiation as wrappers in other parts of the program, and then substitute NumPy array values into them to obtain the final results.
I haven't found a way to port this kind of two-variable function from autograd to jax with comparable performance.
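For the test function used in the examples below, $f(x, y) = x^4 + 2x^2y^2 + y^4$, the operator reduces to a constant, which makes the two ports easy to sanity-check against each other:
$$\nabla^4 f = 24 + 2 \cdot 8 + 24 = 64 \qquad \text{for all } (x, y),$$
so every element of the result arrays r should be 64.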
Examples:
autograd (ex1.py)
import numpy as np
from autograd import elementwise_grad as egrad

dx, dy = 0, 1

def nabla4(w):
    # Biharmonic operator w_xxxx + 2*w_xxyy + w_yyyy, built from
    # nested elementwise gradients.
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )
    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = np.arange(10_000, dtype=np.float64)
y = np.arange(10_000, dtype=np.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
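For context on the translation below: autograd's elementwise_grad differentiates and maps over array inputs in one step, while jax.grad requires a scalar-output function, so the jax port composes grad with vmap. A minimal equivalence check for the first derivative (a sketch, using the same f as above):

import numpy as np
import jax.numpy as jnp
from jax import grad, vmap
from autograd import elementwise_grad as egrad

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = np.arange(5, dtype=np.float64)
y = np.arange(5, dtype=np.float64)

# Elementwise df/dx: autograd does it in one call, jax via grad + vmap.
d_autograd = egrad(f, 0)(x, y)
d_jax = vmap(grad(f, 0))(jnp.asarray(x), jnp.asarray(y))
assert np.allclose(d_autograd, np.asarray(d_jax))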
jax (ex2.py)
import jax
from jax import grad, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1

def nabla4(w):
    # Same operator as in ex1.py; jax.grad needs scalar-output
    # functions, so each nested derivative is mapped over the
    # input arrays with vmap.
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )
    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
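A minimal way to time both scripts fairly (a sketch; with jax, block_until_ready() is needed because dispatch is asynchronous and the loop would otherwise return before the computation finishes):

import time

t0 = time.perf_counter()
r = [nabla4(f)(x, y) for f in w]
for v in r:
    # Wait for jax results before reading the clock;
    # harmless no-op for plain NumPy arrays.
    if hasattr(v, "block_until_ready"):
        v.block_until_ready()
print(f"{time.perf_counter() - t0:.3f} s")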
The program using jax is almost 9x slower than the version using autograd; in more complicated programs the gap is much larger.
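One thing that may narrow the gap, assuming each operator is applied repeatedly to arrays of the same shape and dtype, is compiling it once with jax.jit; as written, fn re-traces all twelve nested grad calls on every invocation. A sketch:

nabla4_jit = jax.jit(nabla4(f))
nabla4_jit(x, y).block_until_ready()  # first call: trace + compile
r = nabla4_jit(x, y)                  # later calls reuse the compiled kernel

Each distinct function in w would still pay a one-time tracing and compilation cost, so this only helps when the compiled operators are reused.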
System info (python version, jaxlib version, accelerator, etc.)