google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Low calculation performance compared to autograd elementwise_grad #21436

Open krysros opened 1 month ago

krysros commented 1 month ago

Description

I use autograd to calculate partial derivatives of functions of two variables (x, y). Due to the end of support for autograd, I'm trying to get the same results using jax.

These functions have the form:

$$\nabla^4 w = \cfrac{\partial^4 w}{\partial x^4} + 2\cfrac{\partial^4 w}{\partial x^2\partial y^2} + \cfrac{\partial^4 w}{\partial y^4}$$

where $w = w(x,y)$.
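The operator can be worked out by hand for the test function used in both scripts below, $f(x,y) = x^4 + 2x^2y^2 + y^4$. This gives a reference value for checking either implementation: every element of each result should equal 64.

$$
\frac{\partial^4 f}{\partial x^4} = 24,\qquad
\frac{\partial^4 f}{\partial x^2\,\partial y^2} = \frac{\partial^2}{\partial y^2}\left(4y^2\right) = 8,\qquad
\frac{\partial^4 f}{\partial y^4} = 24,
$$

$$
\nabla^4 f = 24 + 2\cdot 8 + 24 = 64.
$$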

In other parts of the program I use similar functions obtained by automatic differentiation as wrappers. I then obtain the final results by evaluating them on NumPy arrays.

I haven't found a way to port this kind of two-variable function from autograd to jax with similar performance.

Examples:

autograd (ex1.py)

```python
import numpy as np
from autograd import elementwise_grad as egrad

dx, dy = 0, 1

def nabla4(w):
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )

    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = np.arange(10_000, dtype=np.float64)
y = np.arange(10_000, dtype=np.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.

r = [nabla4(f)(x, y) for f in w]
```
```
(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex1.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 0
Milliseconds      : 813
Ticks             : 8130392
TotalDays         : 9,41017592592593E-06
TotalHours        : 0,000225844222222222
TotalMinutes      : 0,0135506533333333
TotalSeconds      : 0,8130392
TotalMilliseconds : 813,0392
```
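For comparison, autograd's `elementwise_grad` can be mimicked in JAX without `vmap`. If a function is elementwise in its inputs (output element i depends only on input element i), one VJP against a vector of ones returns all per-element derivatives at once. This is essentially how autograd implements `elementwise_grad`. The `egrad` helper below is a hypothetical name, not a JAX API. The trick is only valid under the elementwise assumption, and no claim is made here about its speed relative to the `vmap` version.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def egrad(fun, argnum=0):
    """Hypothetical JAX analog of autograd's elementwise_grad.

    Assumes `fun` is elementwise: out[i] depends only on the i-th element
    of each input, so the VJP with a ones vector equals the Jacobian diagonal.
    """
    def wrapped(*args):
        out, vjp_fn = jax.vjp(fun, *args)
        # vjp_fn returns one cotangent per positional argument.
        return vjp_fn(jnp.ones_like(out))[argnum]
    return wrapped


dx, dy = 0, 1


def nabla4(w):
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )
    return fn


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)
r = nabla4(f)(x, y)
```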

jax (ex2.py)

```python
import jax
from jax import grad, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1

def nabla4(w):
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )

    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.

r = [nabla4(f)(x, y) for f in w]
```
```
(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex2.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 6
Milliseconds      : 906
Ticks             : 69064939
TotalDays         : 7,99362719907407E-05
TotalHours        : 0,00191847052777778
TotalMinutes      : 0,115108231666667
TotalSeconds      : 6,9064939
TotalMilliseconds : 6906,4939
```
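`Measure-Command` times the whole process, so both figures also include interpreter start-up and module imports. Importing jax is typically noticeably slower than importing autograd. In addition, without `jax.jit` every primitive in the nested derivatives is dispatched eagerly on each call. The sketch below separates these costs in-process. For brevity it assumes only the $\partial^4/\partial x^4$ term. Actual timings depend on the machine, so none are quoted here.

```python
import time

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

jax.config.update("jax_enable_x64", True)


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


# Only the d^4/dx^4 term, to keep the example short.
d4x = vmap(grad(grad(grad(grad(f, 0), 0), 0), 0))
d4x_jit = jit(d4x)

x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)


def timed(fn, *args):
    # JAX dispatches work asynchronously; block_until_ready() makes sure the
    # measured interval includes the actual computation.
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


_, t_eager = timed(d4x, x, y)          # op-by-op execution
_, t_first = timed(d4x_jit, x, y)      # includes tracing + XLA compilation
out, t_cached = timed(d4x_jit, x, y)   # reuses the compiled executable

print(f"eager: {t_eager:.4f} s, jit first call: {t_first:.4f} s, "
      f"jit cached: {t_cached:.4f} s")
```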

The jax version takes about 6.9 s, against about 0.81 s for autograd, i.e. it is roughly 8.5x slower. In more complicated programs the difference is much larger.
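One possible remedy is sketched below. It is offered as a sketch, not as the fix given in #22050. The idea is to build the three fourth-order derivatives once per function, `vmap` them over the point arrays, and wrap the result in `jax.jit`, so that tracing and dispatch happen once per function and input shape instead of on every call. The `functools.lru_cache` on `nabla4` is an added assumption: it lets repeated entries of `w` reuse one compiled operator.

```python
import functools

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1


# Assumption: caching on the function object lets repeated entries of `w`
# reuse one compiled operator; distinct functions still compile separately.
@functools.lru_cache(maxsize=None)
def nabla4(w):
    # Fourth-order partials of the scalar function w(x, y).
    d4x = grad(grad(grad(grad(w, dx), dx), dx), dx)
    d2x2y = grad(grad(grad(grad(w, dx), dx), dy), dy)
    d4y = grad(grad(grad(grad(w, dy), dy), dy), dy)

    def point(x, y):
        # Biharmonic operator at a single (x, y) point.
        return d4x(x, y) + 2 * d2x2y(x, y) + d4y(x, y)

    # vmap maps `point` over the arrays; jit traces the whole expression once
    # per input shape/dtype and hands it to XLA as a single program.
    return jit(vmap(point))


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)

w = [f] * 100  # as in ex2.py
r = [nabla4(g)(x, y) for g in w]
```

Compilation is paid once per distinct function and input shape. This approach therefore helps most when each operator is evaluated many times. If many distinct functions are each evaluated only once, compile time may dominate instead.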

System info (python version, jaxlib version, accelerator, etc.)

```
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (tags/v3.10.13-25-g07fbd8e9251-dirty:07fbd8e9251, Dec 28 2023, 15:38:17) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Vero', release='10', version='10.0.22631', machine='AMD64')
```

jakevdp commented 5 days ago

(answered in #22050)