Hi,
I observe a severe performance drop in the `index_update` function on a GPU when changing the data type from complex single to complex double precision. Is there a fundamental reason for that, or is it something that could be improved?
Here is a simple example to illustrate this issue:
```python
import jax
import jax.numpy as jnp
import numpy as np
import time
from jax.config import config
config.update("jax_enable_x64", True)  # Remove this line for single precision

# Set up data
a1 = jnp.zeros((1000, 256), dtype=np.complex128).block_until_ready()
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))  # independent keys for real/imag parts
a2 = (jax.random.normal(key_re, (1000, 256), dtype=jnp.float64)
      + 1.j * jax.random.normal(key_im, (1000, 256), dtype=jnp.float64))
idx = jnp.where(jnp.abs(a2) > 0)
a3 = a2[idx]

# Check the data types used
print(a1.dtype)
print(a2.dtype)

# Define jit'd index_update function
index_update_jitd = jax.jit(jax.ops.index_update)

# Timing including jit compilation
tic = time.perf_counter()
a1 = index_update_jitd(a1, idx, a3).block_until_ready()
print("Execution time including jit: ", time.perf_counter() - tic)

# Reset data
a1 = jnp.zeros((1000, 256), dtype=np.complex128).block_until_ready()

# Timing of the already-jit'd function
tic = time.perf_counter()
a1 = index_update_jitd(a1, idx, a3).block_until_ready()
print("Execution time of jit'd function: ", time.perf_counter() - tic)
```
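For reference, the same update can also be written with JAX's `.at[]` indexed-update syntax, which (as far as I understand) lowers to the same XLA scatter as `jax.ops.index_update` and replaces it in newer JAX versions; a minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)

a1 = jnp.zeros((1000, 256), dtype=jnp.complex128)
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
a2 = (jax.random.normal(key_re, (1000, 256), dtype=jnp.float64)
      + 1j * jax.random.normal(key_im, (1000, 256), dtype=jnp.float64))
idx = jnp.where(jnp.abs(a2) > 0)
a3 = a2[idx]

# .at[idx].set(...) performs the same functional scatter update
updated = a1.at[idx].set(a3)
```

Since it compiles to the same scatter, this form should exhibit the same performance characteristics as the benchmark above.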
I am running this on an NVIDIA V100. With single precision I get the output

```
complex64
complex64
Execution time including jit:  0.1222537359863054
Execution time of jit'd function:  0.0004156339855398983
```
With double precision it is

```
complex128
complex128
Execution time including jit:  3.707396164012607
Execution time of jit'd function:  3.5751446500071324
```
As you can see, the double-precision execution is almost 10,000 times slower, which puzzles me.
Also, the execution time seems to grow roughly linearly with the array size in double precision, whereas for the sizes I checked it barely changed in single precision.
#4907 added a workaround for the specific case of indexed additions with complex numbers. The non-addition cases are harder to fix and would require XLA to use a strategy other than GPU atomic operations for scatters: complex128 values are too wide for the widest GPU atomic instructions.
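Until that is addressed in XLA, one possible user-side workaround (my own sketch, not something from this thread, and verified below only for correctness, not speed) is to scatter the real and imaginary parts separately as float64 arrays, so each scattered element is 8 bytes rather than a 16-byte complex128 value. `index_update_split` is a hypothetical helper name, written with the `.at[].set()` syntax:

```python
import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)

def index_update_split(a, idx, vals):
    # Hypothetical workaround: perform two float64 scatters
    # (real and imaginary parts) instead of one complex128 scatter.
    re = a.real.at[idx].set(vals.real)
    im = a.imag.at[idx].set(vals.imag)
    return jax.lax.complex(re, im)

a = jnp.zeros((8, 4), dtype=jnp.complex128)
idx = (jnp.array([0, 1, 2]), jnp.array([1, 2, 3]))
vals = (jnp.arange(3) + 1j * jnp.arange(3)).astype(jnp.complex128)

out_split = index_update_split(a, idx, vals)
out_direct = a.at[idx].set(vals)  # the complex128 scatter being worked around
```

Whether the two-scatter version actually runs faster would need benchmarking on the GPU in question; it doubles the number of scatters in exchange for narrower elements.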