jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.38k stars · 2.79k forks

jnp.argsort much slower than the numpy version #10434

Open fbartolic opened 2 years ago

fbartolic commented 2 years ago

Here's a comparison of the JAX and numpy versions of argsort on a CPU:

import numpy as np
import jax.numpy as jnp
from jax import config, random
config.update('jax_platform_name', 'cpu')  # force CPU execution

key = random.PRNGKey(42)
key, subkey = random.split(key)

x_jnp = random.uniform(subkey, (100, 10000))
x_np = np.array(x_jnp)

%%timeit
np.argsort(x_np, axis=0)

%%timeit
jnp.argsort(x_jnp, axis=0).block_until_ready()  # wait for async dispatch to finish

In this case jnp.argsort is ~5x slower than np.argsort, and I'm seeing a >20x difference with more realistic arrays. Why is there such a large performance gap between the two implementations?

jakevdp commented 2 years ago

You might find this FAQ helpful: FAQ: Is JAX Faster Than NumPy?.
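For reference, the FAQ's main benchmarking advice boils down to timing a jitted function after a warm-up call, with block_until_ready() so asynchronous dispatch isn't mistaken for speed. A minimal sketch of that pattern (array shape matches the example above; repeat count is arbitrary):

```python
import timeit
import jax
import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (100, 10000))

# Jit the operation and run it once so compilation cost is excluded
# from the measurement.
argsort_jit = jax.jit(lambda a: jnp.argsort(a, axis=0))
argsort_jit(x).block_until_ready()  # warm-up / compile

# block_until_ready() makes the timing include the actual computation,
# not just the async dispatch.
t = timeit.timeit(lambda: argsort_jit(x).block_until_ready(), number=10)
print(f"jitted argsort: {t:.3f}s for 10 runs")
```

Even with this fairer setup, the CPU gap the issue describes remains; the FAQ mainly explains why single-op microbenchmarks can be misleading in general.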

fbartolic commented 2 years ago

Thanks! I read the FAQ, but I didn't expect the difference in performance to get this large.

YouJiacheng commented 2 years ago

@jakevdp It seems to be a pure computational-efficiency problem with the sort primitive on CPU. I find that the sort primitive's performance on GPU is satisfactory, and the sort primitive shares the same ~~translation rule~~ MLIR lowering on all platforms. Maybe XLA uses a parallelism-friendly sort algorithm that is inefficient on CPU.

import numpy as np
import jax.numpy as jnp
from jax import config, random
config.update('jax_platform_name', 'cpu')

key = random.PRNGKey(42)
key, subkey = random.split(key)

x_jnp = random.uniform(subkey, (1000000,))
x_np = np.array(x_jnp)

jnp.argsort(x_jnp, axis=0).block_until_ready() # compile
jnp.sort(x_jnp, axis=0).block_until_ready() # compile
from timeit import timeit
print(timeit('np.argsort(x_np, axis=0)', globals=globals(), number=10)) # 1.1s
print(timeit('jnp.argsort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10)) # 4.2s
print(timeit('jnp.sort(x_jnp, axis=0).block_until_ready()', globals=globals(), number=10)) # 3.7s

jakevdp commented 2 years ago

Yes, in general the XLA project has put much less effort into optimizing operations on CPU than on other backends.

hawkinsp commented 2 years ago

I also note that the slowness is specific to floating-point values. Sorting int32 values is significantly faster. The only difference between the two as far as I can tell is the comparison function.
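A quick way to check this observation on CPU (array size and repeat count are arbitrary choices, not from the thread):

```python
from timeit import timeit

import jax.numpy as jnp
from jax import config, random
config.update('jax_platform_name', 'cpu')  # force CPU execution

key = random.PRNGKey(0)
x_f32 = random.uniform(key, (1_000_000,))
x_i32 = random.randint(key, (1_000_000,), 0, 1_000_000)

# Warm-up calls so compilation time is excluded from the measurement.
jnp.sort(x_f32).block_until_ready()
jnp.sort(x_i32).block_until_ready()

t_f32 = timeit(lambda: jnp.sort(x_f32).block_until_ready(), number=5)
t_i32 = timeit(lambda: jnp.sort(x_i32).block_until_ready(), number=5)
print(f"sort float32: {t_f32:.3f}s   sort int32: {t_i32:.3f}s")
```

If the observation holds, t_i32 should come out noticeably smaller than t_f32 on CPU; the exact ratio will depend on hardware and XLA version.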

sjdv1982 commented 1 year ago

Running into the same issue, I created a workaround where argsort runs under NumPy when only the CPU is available.

https://gist.github.com/sjdv1982/803695055c78b62e5d5dc92a004efa77

It seems to be compatible with jax.grad, but only after disabling a certain assertion in the JAX code.

I am a beginner in JAX, so criticism is welcome; use with care.

jakevdp commented 1 year ago

That's a nice solution! To make it as compatible as possible with JAX transformations, I'd suggest doing the call to numpy via pure_callback instead.
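A minimal sketch of that suggestion (the function name is illustrative, not the gist's actual code): wrap np.argsort in jax.pure_callback, describing the integer result with a ShapeDtypeStruct so the call can be traced under jit.

```python
import numpy as np
import jax
import jax.numpy as jnp

def np_argsort(x, axis=-1):
    # Describe the callback's result: same shape as x, int32 indices.
    out_shape = jax.ShapeDtypeStruct(x.shape, jnp.int32)
    # pure_callback runs the NumPy function outside the traced
    # computation, so the wrapper also works under jax.jit.
    return jax.pure_callback(
        lambda a: np.argsort(a, axis=axis).astype(np.int32),
        out_shape,
        x,
    )

x = jnp.array([3.0, 1.0, 2.0])
print(np_argsort(x))           # eager call
print(jax.jit(np_argsort)(x))  # same result under jit
```

Note that argsort returns integer indices, so there is no gradient to define for the callback itself; differentiating through a downstream gather of the sorted values is what typically needs care.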

sjdv1982 commented 1 year ago

Thank you! I didn't know about pure_callback; I have updated the gist as you suggested. It now runs under unmodified JAX.

I am glad to see that although calling jax.value_and_grad results in three identical calls into the function, JAX is smart enough to coalesce them into one.

tvladyslav commented 1 day ago

https://github.com/openxla/xla/pull/18444 should help, please verify.

Update: reverted due to a performance regression.