Closed: YigitElma closed this issue 1 month ago
The issue here has to do with where the `jax.jit` is applied. When defining `fun`, the `jit` decorator is currently inside the `custom_jvp`. This means that the full computation isn't being JIT compiled. To fix this issue, you can update the code as follows:
```python
# JIT compile `fun` to remove all the overhead of `custom_jvp`
fun = jax.jit(fun)

# It is useful to warm up the cache
fun(x, n, dx).block_until_ready()
fun2(x, n, dx).block_until_ready()

%timeit fun(x, n, dx).block_until_ready()
%timeit fun2(x, n, dx).block_until_ready()
```
On my system I now see roughly the same performance.

Note: for somewhat subtle reasons, you can't actually just use:

```python
@jax.jit
@jax.custom_jvp
def fun(x, n, dx):
    ...
```

You need to apply the `jit` after defining the JVP function.
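As a minimal sketch of that ordering (a toy function `f`, not the thread's `fun`): define the `custom_jvp` function, attach its JVP rule, and only then rebind the name to the jitted version.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    # Standard custom_jvp rule signature: tuples of primals and tangents.
    (x,), (xdot,) = primals, tangents
    return jnp.sin(x), jnp.cos(x) * xdot

# jit only after defjvp has been attached; the jitted function
# still uses the custom rule under differentiation.
f = jax.jit(f)
```

Calling `jax.grad(f)` after this still goes through the custom rule, while plain calls pay no `custom_jvp` dispatch overhead.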
Hope this helps!
Thank you very much @dfm! I had also tried to put `jax.jit` before `custom_jvp`, but I got the following error:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 13
     10 coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
     11 return coef*out
---> 13 @fun.defjvp
     14 def _fun_jvp(a, adot):
     15     """Custom derivative rule for the function.
     16
     17     This is just the same function called with dx+1.
     18     """
     19     (x, n, dx) = a

AttributeError: 'jaxlib.xla_extension.PjitFunction' object has no attribute 'defjvp'
```
I can probably apply `jax.jit` in my project as you suggest, but it would be perfect to be able to use the other decorator order and add `jit` as a decorator instead of jitting manually.
Great! I'd say that supporting the other order of decorators is pretty low priority, since `custom_jvp` functions typically don't live at the top level of a program. In other words, you normally call `custom_jvp` functions deep within some larger function, and you'll typically get the most performance improvement by `jit`ting at the outermost scope. Therefore, I don't think we're likely to add support for calling `defjvp` on a `jit`ted function. For your use case, I don't think it adds too much overhead to add that one line!
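To illustrate the outermost-scope pattern (a hypothetical `inner`/`outer` pair, not code from this thread): the `custom_jvp` function stays un-jitted and the calling function is jitted instead.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def inner(x):
    return jnp.sin(x)

@inner.defjvp
def inner_jvp(primals, tangents):
    (x,), (xdot,) = primals, tangents
    return jnp.sin(x), jnp.cos(x) * xdot

# jit the outermost function; `inner` is compiled as part of the
# larger traced program, so its custom_jvp dispatch overhead
# is paid once at trace time rather than on every call.
@jax.jit
def outer(x):
    return inner(x) ** 2 + x
```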
I'm going to close this now because I think we solved the main issue. As I mentioned, the feature request is unlikely to be implemented, but please feel free to open a new feature request with more details if it becomes a blocker.
@dfm My actual function uses `static_argnums` with `@functools.partial(jit, static_argnums=3)`. For some reason, that caused a problem with `jit`ting the function by

```python
fun = jax.jit(fun, static_argnums=3)
```

Instead I had to use

```python
fun = jax.jit(fun.fun, static_argnums=3)
```

Is there something I am missing?
What was the issue you were seeing? I believe that `jax.jit(fun, static_argnums=3)` should work fine, and `jax.jit(fun.fun, static_argnums=3)` definitely isn't what you want, because then you'll lose the `custom_jvp`!

Can you post a minimal reproducer and the specific error you're seeing?
Let's say I added a redundant `if` statement, and now I need to specify `static_argnums` (my actual code is a bit more complex, but has a similar problem).
```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop
import functools


@jax.custom_jvp
@functools.partial(jax.jit, static_argnums=2)
def fun(x, n, dx):
    """The dx-th derivative of the function x^n."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out


@fun.defjvp
def _fun_jvp(a, adot):
    """Custom derivative rule for the function.

    This is just the same function called with dx+1.
    """
    (x, n, dx) = a
    (xdot, ndot, dxdot) = adot
    f = fun(x, n, dx)
    df = fun(x, n, dx+1)
    return f, df


@functools.partial(jax.jit, static_argnums=2)
def fun2(x, n, dx):
    """Same function without the custom_jvp decorator."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out


x = jnp.arange(1000)
n = 10
dx = 2

fun = jax.jit(fun, static_argnums=2)
# assert fun(x, n, dx).all() == fun2(x, n, dx).all()
%timeit fun(x, n, dx).block_until_ready()
%timeit fun2(x, n, dx).block_until_ready()
```
The error message now is:

```
[... skipping hidden 28 frame]

Cell In[8], line 10
      6 @jax.custom_jvp
      7 @functools.partial(jax.jit, static_argnums=2)
      8 def fun(x, n, dx):
      9     """The dx th derivative of the function x^n."""
---> 10     if dx == 0:
     11         return x**n
     12     out = x**(n-dx)

[... skipping hidden 1 frame]

File ~/miniconda3/envs/desc-env-11/lib/python3.11/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475     raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fun at /tmp/ipykernel_22677/578150583.py:6 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = eq b c
    from line /tmp/ipykernel_22677/578150583.py:10:7 (fun)
```
And, yes, I realized that I was losing `custom_jvp` right after I sent the comment :)
Ah, so in this case the JIT isn't the problem. You need to also label that parameter as non-differentiable:

```python
@functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
def fun(x, n, dx):
    ...
```

You could keep the inner `jit`, but I don't think it adds any value.
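To illustrate the signature change this implies (a toy function `f`, not the thread's code): with `nondiff_argnums`, the non-differentiable argument is passed to the JVP rule first, and `primals`/`tangents` carry only the remaining arguments.

```python
import jax
from functools import partial

# Toy example: `k` is marked non-differentiable, so under
# differentiation it stays a plain Python value and the
# Python-level branch (like the thread's `if dx == 0`) is safe.
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, k):
    if k == 0:
        return x
    return x ** (k + 1)

# The rule receives k first; primals/tangents hold only x.
@f.defjvp
def f_jvp(k, primals, tangents):
    (x,), (xdot,) = primals, tangents
    if k == 0:
        return x, xdot
    return f(x, k), (k + 1) * x ** k * xdot
```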
Oh, thank you very much! It works now!

For practicality, I still want to make a decorator that `jit`s and `custom_jvp`s the function. Here is my version:
```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop
import functools


def custom_jvp_with_jit(func):
    @functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
    def dummy(x, n, dx):
        return func(x, n, dx)

    @dummy.defjvp
    def _dummy_jvp(nondiff_dx, a, adot):
        """Custom derivative rule for the function.

        This is just the same function called with dx+1.
        """
        (x, n) = a
        (xdot, ndot) = adot
        f = dummy(x, n, nondiff_dx)
        df = dummy(x, n, nondiff_dx+1)
        return f, df

    return jax.jit(dummy, static_argnums=2)


@custom_jvp_with_jit
def fun(x, n, dx):
    """The dx-th derivative of the function x^n."""
    if dx == 0:
        return x**n
    out = x**(n-dx)
    coef = fori_loop(0, dx, lambda i, x: x*(n-i), 1)
    return coef*out
```
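A quick sanity check of this decorator on the primal values (a self-contained rebuild of the same code, with a hypothetical float input I chose for illustration):

```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop
import functools

# Self-contained rebuild of the thread's decorator for a quick check.
def custom_jvp_with_jit(func):
    @functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
    def dummy(x, n, dx):
        return func(x, n, dx)

    @dummy.defjvp
    def _dummy_jvp(nondiff_dx, a, adot):
        (x, n) = a
        (xdot, ndot) = adot
        return dummy(x, n, nondiff_dx), dummy(x, n, nondiff_dx + 1)

    return jax.jit(dummy, static_argnums=2)

@custom_jvp_with_jit
def fun(x, n, dx):
    """The dx-th derivative of x**n."""
    if dx == 0:
        return x ** n
    coef = fori_loop(0, dx, lambda i, c: c * (n - i), 1)
    return coef * x ** (n - dx)

x = jnp.arange(1.0, 5.0)
print(fun(x, 3, 0))  # x**3
print(fun(x, 3, 1))  # 3 * x**2
```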
I hope it doesn't have any more problems!

EDIT: I had to change `.defjvp` to comply with this tutorial.
Description
I want to define a function such that one of the input arguments is the derivative order that the user wants. My actual function is far more complex, but here is an example that has the same issue.

This function calculates the `dx`-th derivative of $x^n$. So,

$$dx = 0: \quad f(x)=x^n$$
$$dx = 1: \quad f(x)=nx^{n-1}$$
$$dx = 2: \quad f(x)=n(n-1)x^{n-2}$$

I don't want to use `grad()` to take the derivative. However, when I define the custom derivative to be the same function called with `dx+1`, the execution becomes way slower (even if I call the original function with `dx=0`). I assume this is due to `custom_jvp` being a class, which adds some extra work when the function is called, even if we call it multiple times.

Speed is key for my application, but I also need to define a custom derivative rule for this function for Jacobian calculations involving it. Is there a more optimized way to do that?
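The falling-factorial coefficient in those formulas can be sanity-checked in plain Python (a small illustrative helper, not part of the issue's code):

```python
from functools import reduce

# Coefficient of the dx-th derivative of x**n: n*(n-1)*...*(n-dx+1),
# the same product the fori_loop in the example code accumulates.
def deriv_coef(n, dx):
    return reduce(lambda c, i: c * (n - i), range(dx), 1)

print(deriv_coef(10, 0))  # 1  -> f(x)   = x**10
print(deriv_coef(10, 1))  # 10 -> f'(x)  = 10*x**9
print(deriv_coef(10, 2))  # 90 -> f''(x) = 10*9*x**8
```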
System info (python version, jaxlib version, accelerator, etc.)