fbartolic opened this issue 2 years ago
You might find this FAQ helpful: "FAQ: Is JAX Faster Than NumPy?"
Thanks! I read the FAQ, but I didn't expect the difference in performance to get so large.
@jakevdp It seems that this is purely a computational-efficiency problem with the sort primitive on CPU.
I find the sort primitive's performance on GPU satisfactory, and the sort primitive shares the same MLIR lowering on all platforms. Maybe XLA uses a parallelism-friendly sort algorithm that is inefficient on CPU.
```python
import numpy as np
import jax.numpy as jnp
from jax import config, random
from timeit import timeit

# Force JAX onto the CPU backend for this comparison.
config.update('jax_platform_name', 'cpu')

key = random.PRNGKey(42)
key, subkey = random.split(key)
x_jnp = random.uniform(subkey, (1000000,))  # 1M float32 values in [0, 1)
x_np = np.array(x_jnp)                      # the same data as a NumPy array

# Warm-up calls so that compilation time is excluded from the timings.
jnp.argsort(x_jnp, axis=0).block_until_ready()
jnp.sort(x_jnp, axis=0).block_until_ready()

# Time 10 calls of each; block_until_ready() waits for JAX's asynchronous dispatch.
print(timeit('np.argsort(x_np, axis=0)', globals=globals(), number=10))  # 1.1s
print(timeit('jnp.argsort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10))  # 4.2s
print(timeit('jnp.sort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10))  # 3.7s
```
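One way to look at the "same lowering on all platforms" point is to compare the module JAX emits before XLA compilation with the HLO that XLA produces after optimizing for the current backend. A minimal sketch, assuming a recent JAX with the jax.jit(...).lower(...) staging API; the tiny array is arbitrary. If the lowering is indeed shared, differences between CPU and GPU should show up only in the compiled output.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)  # arbitrary small float32 input

# Module emitted by JAX's lowering rule for sort, before backend-specific optimization.
lowered = jax.jit(jnp.sort).lower(x)
print(lowered.as_text())

# HLO after XLA has compiled and optimized it for the current backend.
print(lowered.compile().as_text())
```

Running this once on a CPU-only machine and once on a GPU machine and diffing the outputs is a cheap way to see where the two paths diverge.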
Yes, in general the XLA project has put much less effort into optimizing operations on CPU than on other backends.
I also note that the slowness is specific to floating-point values. Sorting int32 values is significantly faster. The only difference between the two as far as I can tell is the comparison function.
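The float-versus-integer observation is easy to reproduce. A minimal sketch that times a jitted jnp.sort on float32 and int32 arrays of the same length; the array size and repeat count are arbitrary choices, and no timings are shown because they depend on the machine.

```python
from timeit import timeit

import numpy as np
import jax
import jax.numpy as jnp

n = 200_000  # arbitrary array length for this sketch
rng = np.random.default_rng(0)
x_f32 = jnp.asarray(rng.uniform(size=n).astype(np.float32))
x_i32 = jnp.asarray(rng.integers(0, 2**31 - 1, size=n, dtype=np.int32))

sort = jax.jit(jnp.sort)

for name, x in [("float32", x_f32), ("int32", x_i32)]:
    sort(x).block_until_ready()  # compile once per dtype before timing
    seconds = timeit(lambda x=x: sort(x).block_until_ready(), number=10)
    print(f"{name}: {seconds:.3f} s for 10 sorts")
```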
I ran into the same issue and created a workaround in which argsort is run under NumPy when only the CPU is available:
https://gist.github.com/sjdv1982/803695055c78b62e5d5dc92a004efa77
It seems to be compatible with jax.grad, but only after disabling a certain assertion in the JAX code.
I am a beginner in JAX; criticism is welcome, use with care.
That's a nice solution! To make it as compatible as possible with JAX transformations, I'd suggest making the call to NumPy via pure_callback instead.
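A minimal sketch of that suggestion, assuming the goal is only to compute argsort indices with NumPy on the host when the default backend is CPU. argsort_via_numpy is a hypothetical name, and this is not the code in the gist. jax.pure_callback takes the host function, a ShapeDtypeStruct describing the result, and the arguments; jax.eval_shape is used here so the result spec matches whatever jnp.argsort would return (int32 by default, int64 with x64 enabled).

```python
import numpy as np
import jax
import jax.numpy as jnp


def argsort_via_numpy(x, axis=-1):
    """Hypothetical helper: NumPy argsort on CPU, jnp.argsort on other backends."""
    if jax.default_backend() != "cpu":
        return jnp.argsort(x, axis=axis)

    # Shape and dtype that jnp.argsort would produce, computed without running it.
    out_spec = jax.eval_shape(lambda a: jnp.argsort(a, axis=axis), x)

    def host_argsort(a):
        # Runs on the host with plain NumPy. kind="stable" mirrors jnp.argsort's
        # default stable ordering; the dtype must match out_spec exactly.
        return np.argsort(a, axis=axis, kind="stable").astype(out_spec.dtype)

    return jax.pure_callback(host_argsort, out_spec, x)


x = jnp.array([[0.3, 0.1, 0.2], [0.9, 0.7, 0.8]], dtype=jnp.float32)
print(jax.jit(argsort_via_numpy)(x))  # indices [[1, 2, 0], [1, 2, 0]]
```

This sketch does not address differentiation. Whether jax.grad can trace through it unchanged may depend on the JAX version; the JAX documentation on external callbacks describes pairing pure_callback with jax.custom_jvp when gradients must flow through a callback.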
Thank you! I didn't know about pure_callback; I have updated the gist as you suggested. It runs under unmodified JAX now.
I am glad to see that when calling jax.value_and_grad, there are three identical calls into the function, but JAX is smart enough to coalesce them into one.
https://github.com/openxla/xla/pull/18444 should help; please verify.
Update: reverted due to a performance regression.
Here's a comparison of the JAX and NumPy versions of argsort on a CPU:
In this case, jnp.argsort is ~5x slower than np.argsort. I'm seeing a >20x difference with more realistic arrays. Why is there such a large difference in performance between the two implementations?