google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`custom_jvp` functions are slow when used with `jit` #22513

Closed YigitElma closed 1 month ago

YigitElma commented 1 month ago

Description

I want to define a function where one of the input arguments is the derivative order the user wants. My actual function is far more complex, but here is an example that has the same issue.

import jax
import jax.numpy as jnp
from jax.lax import fori_loop

@jax.custom_jvp
@jax.jit
def fun(x, n, dx):
    """The dx th derivative of the function x^n."""
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out

@fun.defjvp
def _fun_jvp(a, adot):
    """Custom derivative rule for the function.

    This is just the same function called with dx+1.
    """
    (x, n, dx) = a
    (xdot, ndot, dxdot) = adot
    f = fun(x, n, dx)
    df = fun(x, n, dx+1)
    return f, df * xdot  # scale by the input tangent (chain rule)

@jax.jit
def fun2(x, n, dx):
    """Same function without custom_jvp decorator."""
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out

x = jnp.arange(1000)
n = 10
dx = 0

assert jnp.allclose(fun(x, n, dx), fun2(x, n, dx))

%timeit fun(x, n, dx).block_until_ready()
%timeit fun2(x, n, dx).block_until_ready()
73 μs ± 1.05 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
8.6 μs ± 1.26 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

This function calculates the dx-th derivative of $x^n$. So,

$$\mathrm{dx} = 0:\quad f(x) = x^n$$
$$\mathrm{dx} = 1:\quad f(x) = n x^{n-1}$$
$$\mathrm{dx} = 2:\quad f(x) = n(n-1) x^{n-2}$$
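The pattern generalizes: the `fori_loop` accumulates the falling factorial $n(n-1)\cdots(n-\mathrm{dx}+1)$. Writing $k$ for dx (to avoid clashing with the differential), the function and the recursion that the custom JVP rule relies on are:

```latex
f_k(x) \equiv \frac{d^{k}}{dx^{k}}\, x^{n}
       = \Big(\prod_{i=0}^{k-1} (n-i)\Big)\, x^{n-k}
       = \frac{n!}{(n-k)!}\, x^{n-k}, \qquad 0 \le k \le n,

\frac{d}{dx} f_k(x) = f_{k+1}(x).
```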

I don't want to use grad() to take the derivative. However, when I define the custom derivative to be the same function called with dx+1, execution becomes much slower (even when I call the original function with dx=0). I assume this is because custom_jvp wraps the function in a class that adds extra Python work on every call, no matter how many times the function has already been called.

Speed is key for my application, but I also need a custom derivative rule for the Jacobian calculations that involve this function. Is there a more efficient way to do this?
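One detail matters for the Jacobian use case: a `custom_jvp` rule must return the output *tangent*, i.e. the derivative multiplied by the incoming tangent `xdot`, not the bare derivative. A minimal sketch, assuming `n` and `dx` are plain Python ints declared through `nondiff_argnums` (the name `power_derivative` is hypothetical):

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical name; n and dx are non-differentiable Python ints.
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def power_derivative(x, n, dx):
    """The dx-th derivative of x**n."""
    coef = 1
    for i in range(dx):  # falling factorial n*(n-1)*...*(n-dx+1)
        coef *= n - i
    return coef * x ** (n - dx)


@power_derivative.defjvp
def _power_derivative_jvp(n, dx, primals, tangents):
    # With nondiff_argnums, the non-differentiable args are passed first.
    (x,), (x_dot,) = primals, tangents
    y = power_derivative(x, n, dx)
    dy = power_derivative(x, n, dx + 1)
    # Chain rule: scale the derivative by the incoming tangent.
    return y, dy * x_dot


x = jnp.array([1.0, 2.0, 3.0])
jac = jax.jacfwd(power_derivative)(x, 3, 0)  # should be diag(3 * x**2)
```

Because `jacfwd` pushes basis vectors through the rule, returning `dy` without `* x_dot` would fill every column with the same vector instead of producing a diagonal Jacobian.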

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.0 (main, Mar  1 2023, 18:26:19) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='Yigit', release='6.5.0-1025-oem', version='#26-Ubuntu SMP PREEMPT_DYNAMIC Tue Jun 18 12:35:22 UTC 2024', machine='x86_64')
dfm commented 1 month ago

The issue here is where jax.jit is applied. In fun, the jit decorator sits inside custom_jvp, so only the inner function is compiled and the custom_jvp wrapper still runs in Python on every call. To fix this, you can update the code as follows:

# JIT compile `fun` to remove all the overhead of `custom_jvp`
fun = jax.jit(fun)

# It is useful to warm up the cache
fun(x, n, dx).block_until_ready()
fun2(x, n, dx).block_until_ready()

%timeit fun(x, n, dx).block_until_ready()
%timeit fun2(x, n, dx).block_until_ready()

On my system I now see roughly the same performance.
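A self-contained sketch of that fix, assuming the rule scales by `xdot` as a correct JVP should: the rule is registered first, then the whole `custom_jvp` object is wrapped in `jax.jit`, and the custom rule is still used for differentiation.

```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop


@jax.custom_jvp
def fun(x, n, dx):
    """The dx-th derivative of x**n."""
    coef = fori_loop(0, dx, lambda i, acc: acc * (n - i), 1)
    return coef * x ** (n - dx)


@fun.defjvp
def _fun_jvp(primals, tangents):
    x, n, dx = primals
    x_dot, _, _ = tangents  # n and dx are integers; their tangents are ignored
    return fun(x, n, dx), fun(x, n, dx + 1) * x_dot


# jit the custom_jvp function itself, after the rule has been registered,
# so the custom_jvp dispatch is traced once instead of running on every call.
fun_jit = jax.jit(fun)

x = jnp.array([1.0, 2.0, 3.0])
values = fun_jit(x, 3, 1)              # 3 * x**2
slope = jax.grad(fun_jit)(2.0, 3, 0)   # derivative of x**3 at 2.0, via the rule
```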

Note: For somewhat subtle reasons you can't actually just use:

@jax.jit
@jax.custom_jvp
def fun(x, n, dx):
  ...

You need to apply the jit after defining the JVP function.
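The ordering constraint can be checked directly: `jax.jit` returns a compiled-function wrapper rather than a `custom_jvp` object, so it has no `defjvp` method. A minimal sketch with a toy function (`square` is hypothetical):

```python
import jax


@jax.custom_jvp
def square(x):
    return x * x


@square.defjvp
def _square_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return square(x), 2.0 * x * x_dot


# Register the rule on the custom_jvp object first, then wrap it with jit.
square_jit = jax.jit(square)

has_defjvp_before = hasattr(square, "defjvp")     # custom_jvp object
has_defjvp_after = hasattr(square_jit, "defjvp")  # jit wrapper
```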

Hope this helps!

YigitElma commented 1 month ago

Thank you very much @dfm! I had also tried putting jax.jit before custom_jvp, but got the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 13
     10     coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
     11     return coef*out
---> 13 @fun.defjvp
     14 def _fun_jvp(a, adot):
     15     """Custom derivative rule for the function.
     16     
     17     This is just the same function called with dx+1.
     18     """
     19     (x, n, dx) = a

AttributeError: 'jaxlib.xla_extension.PjitFunction' object has no attribute 'defjvp'

I can probably apply jax.jit in my project as you suggest, but it would be ideal to be able to change the decorator order and apply jit as a decorator instead of jitting manually.

dfm commented 1 month ago

Great! I'd say that supporting the other order of decorators is pretty low priority since custom_jvp functions typically don't live at the top level of a program. In other words, normally you call custom_jvp functions deep within some larger function. You'll typically get the most performance improvement by jitting at the outermost scope. Therefore, I don't think we're likely to add support for calling defjvp on a jit function. For your use case, I don't think it adds too much overhead to add that one line!
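To illustrate jitting at the outermost scope, here is a minimal sketch in which a `custom_jvp` helper (hypothetical `power_derivative`, with `n` and `dx` as `nondiff_argnums`) is called inside a larger, made-up `loss` function, and a single `jax.jit` wraps the whole value-and-gradient computation:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def power_derivative(x, n, dx):
    coef = 1
    for i in range(dx):  # falling factorial n*(n-1)*...*(n-dx+1)
        coef *= n - i
    return coef * x ** (n - dx)


@power_derivative.defjvp
def _power_derivative_jvp(n, dx, primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return power_derivative(x, n, dx), power_derivative(x, n, dx + 1) * x_dot


def loss(x):
    # Hypothetical larger computation that calls the custom_jvp helper.
    return jnp.sum(power_derivative(x, 3, 0) + jnp.sin(x))


# One jit around the outermost function compiles everything, custom rule included.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

x = jnp.array([0.0, 1.0, 2.0])
val, g = loss_and_grad(x)
```

The inner helper needs no jit of its own here; the outer `jax.jit` traces through it once.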

dfm commented 1 month ago

I'm going to close this now because I think we solved the main issue. As I mentioned, I think the feature request is unlikely to be executed, but please feel free to open a new feature request with more details if it becomes a blocker.

YigitElma commented 1 month ago

@dfm My actual function uses static_argnums via @functools.partial(jit, static_argnums=3). For some reason, that caused a problem when jitting the function with

fun = jax.jit(fun, static_argnums=3)

Instead I had to,

fun = jax.jit(fun.fun, static_argnums=3)

Is there something I am missing?

dfm commented 1 month ago

What was the issue you were seeing? I believe that jax.jit(fun, static_argnums=3) should work fine, and jax.jit(fun.fun, static_argnums=3) definitely isn't what you want because then you'll lose the custom_jvp!

Can you post a minimal reproducer and the specific error you're seeing?
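A small way to see why jitting `fun.fun` loses the rule: a `custom_jvp` object keeps the undecorated function in `.fun` (the attribute used above), so jitting that attribute compiles plain autodiff instead of the custom rule. This sketch borrows the `log1pexp` example from the JAX custom-derivatives tutorial, where the naive gradient overflows for large inputs; it is illustrative, not the reporter's function.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Naive formula: plain autodiff of this overflows for large x.
    return jnp.log(1.0 + jnp.exp(x))


@log1pexp.defjvp
def _log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Stable derivative, sigmoid(x) written as 1 - 1 / (1 + e^x).
    return log1pexp(x), (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot


with_rule = jax.jit(log1pexp)         # keeps the custom JVP rule
without_rule = jax.jit(log1pexp.fun)  # undecorated function: rule is gone

g_with = jax.grad(with_rule)(100.0)
g_without = jax.grad(without_rule)(100.0)
```

In float32, `exp(100.0)` overflows, so plain autodiff multiplies `0 * inf` and the gradient is not finite, while the custom rule still returns a usable value.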

YigitElma commented 1 month ago

Let's say I added a redundant if statement, and now I need to specify static_argnums (my actual code is a bit more complex, but has a similar problem).

import jax
import jax.numpy as jnp
from jax.lax import fori_loop
import functools

@jax.custom_jvp
@functools.partial(jax.jit, static_argnums=2)
def fun(x, n, dx):
    """The dx th derivative of the function x^n."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out

@fun.defjvp
def _fun_jvp(a, adot):
    """Custom derivative rule for the function.

    This is just the same function called with dx+1.
    """
    (x, n, dx) = a
    (xdot, ndot, dxdot) = adot
    f = fun(x, n, dx)
    df = fun(x, n, dx+1)
    return f, df * xdot  # scale by the input tangent (chain rule)

@functools.partial(jax.jit, static_argnums=2)
def fun2(x, n, dx):
    """Same function without custom_jvp decorator."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out

x = jnp.arange(1000)
n = 10
dx = 2

fun = jax.jit(fun, static_argnums=2)

# assert fun(x, n, dx).all() == fun2(x, n, dx).all()

%timeit fun(x, n, dx).block_until_ready()
%timeit fun2(x, n, dx).block_until_ready()

The error message is now:

[... skipping hidden 28 frame]

Cell In[8], line 10
      6 @jax.custom_jvp
      7 @functools.partial(jax.jit, static_argnums=2)
      8 def fun(x, n, dx):
      9     """The dx th derivative of the function x^n."""
---> 10     if dx == 0:
     11         return x**n
     12     out = x**(n-dx)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/desc-env-11/lib/python3.11/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function fun at /tmp/ipykernel_22677/578150583.py:6 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = eq b c
    from line /tmp/ipykernel_22677/578150583.py:10:7 (fun)

And, yes, I realized that I was losing custom_jvp right after I sent the comment :)

dfm commented 1 month ago

Ah, so in this case the JIT isn't the problem. You also need to mark that parameter as non-differentiable:

@functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
def fun(x, n, dx):
  ...

You could keep the inner jit, but I don't think it adds any value.
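Filled out, a minimal sketch of that suggestion, assuming the rule scales by `xdot`: with `nondiff_argnums=(2,)` the JVP rule receives `dx` first and only `(x, n)` in `primals`/`tangents`, and `dx` stays a plain Python int, so the `if dx == 0` branch no longer hits a tracer. Jitting afterwards with `static_argnums=2` keeps it that way.

```python
import functools

import jax
import jax.numpy as jnp
from jax.lax import fori_loop


@functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
def fun(x, n, dx):
    """The dx-th derivative of x**n; dx arrives as a plain Python int."""
    if dx == 0:  # Python control flow is fine because dx is not traced
        return x**n
    out = x ** (n - dx)
    coef = fori_loop(0, dx, lambda i, acc: acc * (n - i), 1)
    return coef * out


@fun.defjvp
def _fun_jvp(dx, primals, tangents):
    # Non-differentiable args come first; primals/tangents hold only (x, n).
    x, n = primals
    x_dot, _ = tangents  # n is an integer, so its tangent is ignored
    return fun(x, n, dx), fun(x, n, dx + 1) * x_dot


# jit the custom_jvp function itself, with dx static so it stays a Python int.
fun_jit = jax.jit(fun, static_argnums=2)

x = jnp.array([1.0, 2.0, 3.0])
jac = jax.jacfwd(fun_jit)(x, 3, 0)  # should be diag(3 * x**2)
```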

YigitElma commented 1 month ago

Oh thank you very much! It works now!

For practicality, I still want a decorator that both applies custom_jvp and jits the function. Here is my version:

import jax
import jax.numpy as jnp
from jax.lax import fori_loop
import functools

def custom_jvp_with_jit(func):
    @functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
    def dummy(x, n, dx):
        return func(x, n, dx)

    @dummy.defjvp
    def _dummy_jvp(nondiff_dx, a, adot):
        """Custom derivative rule for the function.

        This is just the same function called with dx+1.
        """
        (x, n) = a
        (xdot, ndot) = adot
        f = dummy(x, n, nondiff_dx)
        df = dummy(x, n, nondiff_dx+1)
        return f, df * xdot  # scale by the input tangent (chain rule)

    return jax.jit(dummy, static_argnums=2)

@custom_jvp_with_jit
def fun(x, n, dx):
    """The dx th derivative of the function x^n."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out

I hope it doesn't have any more problems!!!

EDIT: I had to change the .defjvp rule's signature (the nondiff argument now comes first) to comply with this tutorial.
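Because the rule calls the wrapped function again with `dx + 1`, higher-order derivatives reuse the same rule. A quick check of the decorator against plain autodiff of `x**5`, assuming the tangent is scaled by `xdot` (the inner name `wrapped` and the reference `plain` are hypothetical):

```python
import functools

import jax
import jax.numpy as jnp
from jax.lax import fori_loop


def custom_jvp_with_jit(func):
    """Wrap func(x, n, dx) with a custom JVP (dx -> dx + 1) and jit it."""

    @functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
    def wrapped(x, n, dx):
        return func(x, n, dx)

    @wrapped.defjvp
    def _wrapped_jvp(dx, primals, tangents):
        x, n = primals
        x_dot, _ = tangents
        return wrapped(x, n, dx), wrapped(x, n, dx + 1) * x_dot

    return jax.jit(wrapped, static_argnums=2)


@custom_jvp_with_jit
def fun(x, n, dx):
    """The dx-th derivative of x**n."""
    if dx == 0:
        return x**n
    out = x ** (n - dx)
    coef = fori_loop(0, dx, lambda i, acc: acc * (n - i), 1)
    return coef * out


def plain(x):
    # Reference: ordinary autodiff of x**5.
    return x**5


x0 = 2.0
first = jax.grad(fun)(x0, 5, 0)             # uses the rule once
second = jax.grad(jax.grad(fun))(x0, 5, 0)  # rule applied recursively
```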