jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

array.at.set is incredibly slow for complex128 dtype #24872

Status: Open. chrisrothUT opened this issue 2 days ago.

chrisrothUT commented 2 days ago

Description

For some reason, array.at[...].set is very slow for complex128 arrays. Below I show that it is much faster to split the arrays into real and imaginary parts, call .at[...].set on each part, and then recombine them into a complex array.

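import time

import jax
from jax import numpy as jnp

# Must run before any arrays are created; enables complex128/float64.
jax.config.update("jax_enable_x64", True)

@jax.jit
def naive_set(x, x2, inds):
  # Direct scatter-update on the complex128 array.
  return x.at[inds].set(x2)

@jax.jit
def complex_set(x, x2, inds):
  # Two float64 scatters (real and imaginary parts), then recombine.
  # Equivalent to naive_set only when inds contains no duplicates.
  return jax.lax.complex(x.real.at[inds].set(x2.real),
                         x.imag.at[inds].set(x2.imag))

x = jnp.zeros([10000000], dtype=jnp.complex128)
x2 = jnp.zeros([10000], dtype=jnp.complex128)
inds = jnp.arange(10000)

# Warm up (compile) both functions and wait for them to finish, so the
# timings below exclude compilation and any still-pending device work.
jax.block_until_ready(naive_set(x, x2, inds))
jax.block_until_ready(complex_set(x, x2, inds))

t = time.perf_counter()
jax.block_until_ready(naive_set(x, x2, inds))
print('set time=', time.perf_counter() - t)

t = time.perf_counter()
jax.block_until_ready(complex_set(x, x2, inds))
print('complex set time=', time.perf_counter() - t)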

set time= 0.07047343254089355
complex set time= 0.0006287097930908203

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='workergpu158', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')

hawkinsp commented 2 days ago

As it happens, we already have a workaround in JAX that avoids this slow path for scatter-add and scatter-sub, but not for scatter-update. It should be fairly easy to extend it to scatter-update as well.

(The issue is that 128-bit scatters are currently expensive in XLA, because NVIDIA GPUs don't have a 16-byte atomic write operation.)
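To make the size argument concrete, the short check below (a sketch, not part of the original thread) prints the element widths involved: a complex128 element is 16 bytes, while each of its float64 parts is 8 bytes. It assumes jax_enable_x64 is on; without it, JAX silently uses complex64 instead.

```python
import jax
import jax.numpy as jnp

# Without x64 enabled, JAX downcasts complex128 requests to complex64.
jax.config.update("jax_enable_x64", True)

z = jnp.zeros(4, dtype=jnp.complex128)
# 16 bytes per element: per the comment above, there is no 16-byte atomic write on NVIDIA GPUs.
print(z.dtype, z.dtype.itemsize)            # complex128 16
# The real (and imaginary) part is float64, 8 bytes per element.
print(z.real.dtype, z.real.dtype.itemsize)  # float64 8
```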

hawkinsp commented 1 day ago

Actually, thinking about this a bit more, it's somewhat problematic to split into real and imaginary parts.

If there are multiple updates to the same index, then it's unspecified which update "wins". If we performed updates to both real and imaginary parts separately, you might get the real part of one and the imaginary part of another. Only if you promised us the indices are non-overlapping would it be safe for us to do that. Is that true in your case?
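One way to answer that question for concrete index arrays is a host-side uniqueness check. The helper below, indices_are_unique, is hypothetical (not a JAX API) and is a sketch under stated assumptions: it runs outside jit, because jnp.unique has a data-dependent output shape, and it assumes non-negative, in-bounds indices, since JAX wraps negative indices (so -1 and n-1 would alias the same element without being caught here).

```python
import jax.numpy as jnp

def indices_are_unique(inds) -> bool:
    """Hypothetical helper: True if no index value repeats.

    Call on concrete arrays, not inside jit. Assumes non-negative,
    in-bounds indices; negative indices that alias others are not detected.
    """
    return bool(jnp.unique(inds).size == inds.size)

print(indices_are_unique(jnp.arange(10000)))        # True
print(indices_are_unique(jnp.array([0, 5, 5, 7])))  # False
```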

It's easier for add and sub because accumulating updates is order-independent (addition is commutative and associative): we can apply the updates in any order and still get the same result, up to floating-point rounding error.
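As an illustration of why the split is safe for add even with repeated indices, here is a small sketch (the arrays are made up for the example): a direct complex scatter-add and a split real/imaginary scatter-add give the same result when index 1 appears twice, because each part simply accumulates its own sum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jnp.zeros(4, dtype=jnp.complex128)
inds = jnp.array([1, 1, 3])                 # index 1 is updated twice
vals = jnp.array([1 + 2j, 3 + 4j, 5 + 6j])  # complex128 with x64 enabled

# Direct complex scatter-add.
direct = x.at[inds].add(vals)

# Split scatter-add: real and imaginary parts accumulate independently,
# so the order in which duplicate updates land does not matter.
split = jax.lax.complex(x.real.at[inds].add(vals.real),
                        x.imag.at[inds].add(vals.imag))

print(jnp.allclose(direct, split))  # True
```

The same experiment with .set instead of .add would have no single correct answer for index 1, which is exactly the problem described above.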

chrisrothUT commented 1 day ago

I see the issue. Yes, in our case the indices are non-overlapping, so the two functions are exactly equivalent.

Maybe the solution is to provide a warning about how scatter-update is slow with complex128 dtype and suggest updating the real and imaginary parts separately?
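A user-side version of that suggestion might look like the sketch below. set_complex_split is a hypothetical helper, not a JAX API: it splits complex128 scatter-updates into two float64 scatters only when the caller promises unique indices (forwarding the existing unique_indices hint to .at[...].set), and otherwise warns and falls back to the ordinary scatter. Under jit, unique_indices must be static, and the warning fires at trace time rather than on every call.

```python
import warnings

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

def set_complex_split(x, inds, vals, *, unique_indices=False):
    """Hypothetical helper (not a JAX API): returns x with x[inds] = vals.

    For complex128 x, splits into real/imag float64 scatters only if the
    caller promises unique indices; otherwise warns and uses the plain path.
    """
    if x.dtype == np.complex128:
        if unique_indices:
            re = x.real.at[inds].set(vals.real, unique_indices=True)
            im = x.imag.at[inds].set(vals.imag, unique_indices=True)
            return jax.lax.complex(re, im)
        warnings.warn(
            "complex128 scatter-update can be slow on GPU; if the indices are "
            "non-overlapping, pass unique_indices=True to update the real and "
            "imaginary parts separately."
        )
    return x.at[inds].set(vals, unique_indices=unique_indices)

# unique_indices is a Python bool that changes control flow, so mark it static.
fast_set = jax.jit(set_complex_split, static_argnames="unique_indices")
```

The design leaves the uniqueness promise with the caller, matching the point above that only the user can guarantee non-overlapping indices.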