jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jax.jit` slows down the code a lot on a function with simple array operations and `jnp.roll()` #24373

Open · pmocz opened this issue 4 days ago

pmocz commented 4 days ago

Description

I get a significant (~4x) slowdown in my JAX code when I add @jax.jit to my main update function, which manipulates large arrays with element-wise math and jnp.roll().

A minimal reproducer is included below; removing the @jax.jit around the update() function (the line marked with the comment # XXX) speeds up the code considerably. The slowdown is not due to compile-time overhead. I'm quite puzzled by this behavior and think it may be a bug in JAX or XLA. What is the best way to get to the bottom of this issue? To reproduce, run python euler.py with and without the jit decorator around update():

import jax
import jax.numpy as jnp
import time

# simulation parameters
N = 1024
boxsize = 1.0
dx = boxsize / N
vol = dx**2
dt = 0.0001

@jax.jit
def get_conserved(rho, vx, vy, P):
    """Calculate the conserved variables from the primitive variables"""

    Mass = rho * vol
    Momx = rho * vx * vol
    Momy = rho * vy * vol
    Energy = (P / (5 / 3 - 1) + 0.5 * rho * (vx**2 + vy**2)) * vol

    return Mass, Momx, Momy, Energy

@jax.jit
def get_primitive(Mass, Momx, Momy, Energy):
    """Calculate the primitive variable from the conserved variables"""

    rho = Mass / vol
    vx = Momx / rho / vol
    vy = Momy / rho / vol
    P = (Energy / vol - 0.5 * rho * (vx**2 + vy**2)) * (5 / 3 - 1)

    return rho, vx, vy, P

@jax.jit
def get_gradient(f):
    """Calculate the gradients of a field"""

    f_dx = (jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)) / (2 * dx)
    f_dy = (jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)) / (2 * dx)

    return f_dx, f_dy

@jax.jit
def extrapolate_to_face(f, f_dx, f_dy):
    """Extrapolate the field from face centers to faces using gradients"""

    f_XL = f - f_dx * dx / 2
    f_XL = jnp.roll(f_XL, -1, axis=0)
    f_XR = f + f_dx * dx / 2

    f_YL = f - f_dy * dx / 2
    f_YL = jnp.roll(f_YL, -1, axis=1)
    f_YR = f + f_dy * dx / 2

    return f_XL, f_XR, f_YL, f_YR

@jax.jit
def apply_fluxes(F, flux_F_X, flux_F_Y):
    """Apply fluxes to conserved variables to update solution state"""

    F += -dt * dx * flux_F_X
    F += dt * dx * jnp.roll(flux_F_X, 1, axis=0)
    F += -dt * dx * flux_F_Y
    F += dt * dx * jnp.roll(flux_F_Y, 1, axis=1)

    return F

@jax.jit
def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P):
    """Calculate fluxes between 2 states"""

    # left and right energies
    en_L = P / (5 / 3 - 1) + 0.5 * rho_L * (vx_L**2 + vy_L**2)
    en_R = P / (5 / 3 - 1) + 0.5 * rho_R * (vx_R**2 + vy_R**2)

    # compute star (averaged) states
    rho_star = 0.5 * (rho_L + rho_R)
    momx_star = 0.5 * (rho_L * vx_L + rho_R * vx_R)
    momy_star = 0.5 * (rho_L * vy_L + rho_R * vy_R)
    en_star = 0.5 * (en_L + en_R)

    P_star = (5 / 3 - 1) * (en_star - 0.5 * (momx_star**2 + momy_star**2) / rho_star)

    flux_Mass = momx_star
    flux_Momx = momx_star**2 / rho_star + P_star
    flux_Momy = momx_star * momy_star / rho_star
    flux_Energy = (en_star + P_star) * momx_star / rho_star

    # add stabilizing diffusive term
    flux_Mass -= 0.5 * 0.5 * (rho_L - rho_R)
    flux_Momx -= 0.5 * 0.5 * (rho_L * vx_L - rho_R * vx_R)
    flux_Momy -= 0.5 * 0.5 * (rho_L * vy_L - rho_R * vy_R)
    flux_Energy -= 0.5 * 0.5 * (en_L - en_R)

    return flux_Mass, flux_Momx, flux_Momy, flux_Energy

@jax.jit  # <---  XXX Adding this line slows down the code a lot!!
def update(Mass, Momx, Momy, Energy):
    """Take a simulation timestep"""

    rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy)

    rho_dx, rho_dy = get_gradient(rho)
    vx_dx, vx_dy = get_gradient(vx)
    vy_dx, vy_dy = get_gradient(vy)

    rho_XL, rho_XR, rho_YL, rho_YR = extrapolate_to_face(rho, rho_dx, rho_dy)
    vx_XL, vx_XR, vx_YL, vx_YR = extrapolate_to_face(vx, vx_dx, vx_dy)
    vy_XL, vy_XR, vy_YL, vy_YR = extrapolate_to_face(vy, vy_dx, vy_dy)

    flux_Mass_X, flux_Momx_X, flux_Momy_X, flux_Energy_X = get_flux(
        rho_XL, rho_XR, vx_XL, vx_XR, vy_XL, vy_XR, P
    )
    flux_Mass_Y, flux_Momy_Y, flux_Momx_Y, flux_Energy_Y = get_flux(
        rho_YL, rho_YR, vy_YL, vy_YR, vx_YL, vx_YR, P
    )

    Mass = apply_fluxes(Mass, flux_Mass_X, flux_Mass_Y)
    Momx = apply_fluxes(Momx, flux_Momx_X, flux_Momx_Y)
    Momy = apply_fluxes(Momy, flux_Momy_X, flux_Momy_Y)
    Energy = apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y)

    return Mass, Momx, Momy, Energy

def main():
    """Finite Volume simulation"""

    # Setup
    xlin = jnp.linspace(0.5 * dx, boxsize - 0.5 * dx, N)
    X, Y = jnp.meshgrid(xlin, xlin, indexing="ij")

    rho = 1.0 + (jnp.abs(Y - 0.5) < 0.25)
    vx = -0.5 + (jnp.abs(Y - 0.5) < 0.25)
    vy = 0.1 * jnp.sin(4 * jnp.pi * X)
    P = 2.5 * jnp.ones(X.shape)

    Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P)

    # Main Loop
    tic = time.time()
    for n_iter in range(40):

        Mass, Momx, Momy, Energy = jax.block_until_ready(
            update(Mass, Momx, Momy, Energy)
        )

        cell_updates = X.shape[0] * X.shape[1] * (n_iter + 1)  # updates completed so far
        total_time = time.time() - tic
        mcups = cell_updates / (1e6 * total_time)
        print("  million cell updates / second: ", mcups)

    print("Total time: ", total_time)

if __name__ == "__main__":
    main()
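For reference, one way to rule out compile-time overhead explicitly is to compile update ahead of time with JAX's AOT API and time only the execution. The snippet below is a rough sketch meant to replace the timed loop in main(); it assumes the arrays returned by get_conserved and is not part of euler.py itself:

    # compile once, outside the timed loop (AOT sketch, not in the original script)
    compiled_update = update.lower(Mass, Momx, Momy, Energy).compile()

    tic = time.time()
    for _ in range(40):
        Mass, Momx, Momy, Energy = jax.block_until_ready(
            compiled_update(Mass, Momx, Momy, Energy)
        )
    print("  execution-only time: ", time.time() - tic)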

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.2
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='C916PXT6XW', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:50:00 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_ARM64_T6031', machine='arm64')
pmocz commented 3 days ago

I've simplified the code to highlight the issue:

import jax
import jax.numpy as jnp
import time

@jax.jit
def get_gradient(f):
    """Calculate the gradients of a field"""

    f_dx = jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)
    f_dy = jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)

    return f_dx, f_dy

@jax.jit
def extrapolate_to_face(f, f_dx, f_dy):
    """Extrapolate the field from face centers to faces using gradients"""

    f_XL = f - f_dx
    f_XL = jnp.roll(f_XL, -1, axis=0)
    f_XR = f + f_dx

    f_YL = f - f_dy
    f_YL = jnp.roll(f_YL, -1, axis=1)
    f_YR = f + f_dy

    return f_XL, f_XR, f_YL, f_YR

@jax.jit
def apply_fluxes(F, flux_F_X, flux_F_Y):
    """Apply fluxes to conserved variables to update solution state"""

    F += -flux_F_X
    F += jnp.roll(flux_F_X, 1, axis=0)
    F += -flux_F_Y
    F += jnp.roll(flux_F_Y, 1, axis=1)

    return F

@jax.jit
def get_flux(A_L, A_R, B_L, B_R):
    """Calculate fluxes between 2 states"""

    A_star = 0.5 * (A_L + A_R)
    B_star = 0.5 * (B_L + B_R)

    flux_A = B_star
    flux_B = B_star**2 / A_star

    flux_A -= 0.1 * (A_L - A_R)
    flux_B -= 0.1 * (B_L - B_R)

    return flux_A, flux_B

# @jax.jit  # <---  XXX Adding this line slows down the code a lot!!
def update(A, B):
    """Take a simulation timestep"""

    A_dx, A_dy = get_gradient(A)
    B_dx, B_dy = get_gradient(B)

    A_XL, A_XR, A_YL, A_YR = extrapolate_to_face(A, A_dx, A_dy)
    B_XL, B_XR, B_YL, B_YR = extrapolate_to_face(B, B_dx, B_dy)

    flux_A_X, flux_B_X = get_flux(A_XL, A_XR, B_XL, B_XR)
    flux_A_Y, flux_B_Y = get_flux(A_YL, A_YR, B_YL, B_YR)

    A = apply_fluxes(A, flux_A_X, flux_A_Y)
    B = apply_fluxes(B, flux_B_X, flux_B_Y)

    return A, B

@jax.jit
def update_compiled_SLOW(A, B):
    return update(A, B)

def main():

    N = 1024

    A = jnp.ones((N, N))
    B = jnp.ones((N, N))
    tic = time.time()
    for _ in range(200):
        A, B = update(A, B)
    print("Total time not compiled: ", time.time() - tic)

    A = jnp.ones((N, N))
    B = jnp.ones((N, N))
    tic = time.time()
    for _ in range(200):
        A, B = update_compiled_SLOW(A, B)
    print("Total time compiled: ", time.time() - tic)

if __name__ == "__main__":
    main()

gives:

Total time not compiled:  0.6709847450256348
Total time compiled:  2.1534647941589355
jakevdp commented 3 days ago

Thanks for the report! This is definitely unexpected, and points to some compiler issue.

I updated your timing to separate out the first call, use block_until_ready to avoid issues due to asynchronous dispatch, and use IPython's %timeit syntax for better fidelity:

_ = jax.block_until_ready(update(A, B))
%timeit jax.block_until_ready(update(A, B))

_ = jax.block_until_ready(update_compiled_SLOW(A, B))
%timeit jax.block_until_ready(update_compiled_SLOW(A, B))

This is the result on Colab CPU:

44.1 ms ± 7.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
165 ms ± 27.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and this is the result on a Colab T4 GPU:

2.72 ms ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.21 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
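
(Outside IPython, a roughly equivalent measurement can be made with a small helper like the sketch below; the bench function is illustrative only, and assumes A and B as defined in main() above.)

def bench(fn, *args, n=10):
    jax.block_until_ready(fn(*args))  # warm-up call; includes compilation for jitted functions
    tic = time.perf_counter()
    for _ in range(n):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - tic) / n

print("update:              ", bench(update, A, B))
print("update_compiled_SLOW:", bench(update_compiled_SLOW, A, B))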

So it seems this issue is particular to the XLA:CPU compiler. It may be worth reporting this upstream at https://github.com/openxla/xla, though it would be useful to try to reduce the repro even further.
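
If it helps with the upstream report, the compiler output for the slow function can be extracted with JAX's AOT API; a sketch:

lowered = update_compiled_SLOW.lower(A, B)
print(lowered.as_text())            # StableHLO before XLA optimizations
print(lowered.compile().as_text())  # optimized HLO that is actually executed

Alternatively, running with XLA_FLAGS=--xla_dump_to=<some directory> dumps all HLO modules to disk.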

pmocz commented 3 days ago

Thanks for taking a look at this @jakevdp, and for pinpointing that this seems to be a CPU-only issue. Definitely unexpected. What is really weird is that if I comment out some simple terms in the get_flux function, namely flux_A -= 0.1 * (A_L - A_R) and flux_B -= 0.1 * (B_L - B_R), then the issue goes away.
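
For concreteness, the variant with those terms removed looks like this (the two commented-out lines are the only change from get_flux above):

@jax.jit
def get_flux(A_L, A_R, B_L, B_R):
    """Calculate fluxes between 2 states"""

    A_star = 0.5 * (A_L + A_R)
    B_star = 0.5 * (B_L + B_R)

    flux_A = B_star
    flux_B = B_star**2 / A_star

    # with these diffusive terms removed, the jitted update() is no longer slow
    # flux_A -= 0.1 * (A_L - A_R)
    # flux_B -= 0.1 * (B_L - B_R)

    return flux_A, flux_B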

I will make an issue with the XLA team as well

pmocz commented 3 days ago

XLA issue is raised here: https://github.com/openxla/xla/issues/18478