Open chrisrothUT opened 2 days ago
As it happens, we have a workaround in JAX to avoid this slow behavior for scatter-add and scatter-sub, but not scatter-update. It should be fairly easy to extend it to scatter-update as well.
(The issue is that 128-bit scatters are currently expensive in XLA, because NVIDIA GPUs don't have a 16-byte atomic write operation.)
Actually, thinking about this a bit more, it's somewhat problematic to split into real and imaginary parts.
If there are multiple updates to the same index, then it's unspecified which update "wins". If we performed updates to both real and imaginary parts separately, you might get the real part of one and the imaginary part of another. Only if you promised us the indices are non-overlapping would it be safe for us to do that. Is that true in your case?
It's easier for `add` and `sub` because those are associative operations; we can apply the updates in any order and still get the same result, up to floating-point error.
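To make the associativity point concrete, here is a small sketch (array contents and indices are made up for illustration): with duplicate indices, `.at[idx].add` accumulates every update so the order doesn't matter, whereas `.at[idx].set` leaves whichever update "wins" unspecified.

```python
import jax
jax.config.update("jax_enable_x64", True)  # complex128 requires x64 mode
import jax.numpy as jnp

a = jnp.zeros(4, dtype=jnp.complex128)
idx = jnp.array([2, 2, 2])                      # three updates to the same index
vals = jnp.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=jnp.complex128)

# Associative: all three updates accumulate, regardless of order.
added = a.at[idx].add(vals)                     # a[2] becomes 6+6j

# Not associative: which of the three values ends up at a[2] is unspecified.
overwritten = a.at[idx].set(vals)
```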
I see the issue. Yes, in our case the indices are non-overlapping so these functions are strictly the same.
Maybe the solution is to provide a warning about how scatter-update is slow with complex128 dtype and suggest updating the real and imaginary parts separately?
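The suggested workaround could look something like the sketch below (function name and shapes are illustrative, not an existing JAX API); as discussed above, it is only safe when the scatter indices are non-overlapping, since real and imaginary parts are updated in two independent scatters.

```python
import jax
jax.config.update("jax_enable_x64", True)  # complex128 requires x64 mode
import jax.numpy as jnp

def set_via_parts(a, idx, vals):
    """Scatter-update a complex array by updating the real and imaginary
    parts separately, then recombining. Only equivalent to a.at[idx].set(vals)
    when idx contains no duplicates."""
    re = jnp.real(a).at[idx].set(jnp.real(vals))
    im = jnp.imag(a).at[idx].set(jnp.imag(vals))
    return jax.lax.complex(re, im)

a = jnp.zeros(8, dtype=jnp.complex128)
idx = jnp.array([1, 4, 6])                      # non-overlapping indices
vals = jnp.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=jnp.complex128)
out = set_via_parts(a, idx, vals)
```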
Description
For some reason `array.at.set` is incredibly slow with complex128 datatypes. Here I show it is much faster to split the arrays into real and imaginary parts before calling `array.at.set` and then recombine them into a complex array afterwards.

```
set time= 0.07047343254089355
complex set time= 0.0006287097930908203
```
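A minimal sketch of such a benchmark (array size and indices are assumptions; absolute timings will differ by hardware, and on a GPU you would time compiled/warmed-up calls rather than the first trace):

```python
import time
import jax
jax.config.update("jax_enable_x64", True)  # complex128 requires x64 mode
import jax.numpy as jnp

a = jnp.zeros(1_000_000, dtype=jnp.complex128)
idx = jnp.arange(0, 1_000_000, 10)              # non-overlapping indices
vals = (jnp.ones(idx.shape) + 1j).astype(jnp.complex128)

# Direct complex scatter-update (the slow path this issue reports).
t0 = time.perf_counter()
direct = a.at[idx].set(vals)
direct.block_until_ready()
print("set time=", time.perf_counter() - t0)

# Split into real/imag parts, update each, recombine.
t0 = time.perf_counter()
re = jnp.real(a).at[idx].set(jnp.real(vals))
im = jnp.imag(a).at[idx].set(jnp.imag(vals))
split = jax.lax.complex(re, im)
split.block_until_ready()
print("complex set time=", time.perf_counter() - t0)
```

Because the indices here don't overlap, both paths produce identical results.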
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='workergpu158', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')