jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Performance of index_update on GPU with complex double data type #4115

Open markusschmitt opened 4 years ago

markusschmitt commented 4 years ago

Hi, I observe a severe drop in performance of the index_update function on a GPU when changing the data type from complex single to complex double. Is there a fundamental reason for this, or is it something that could be improved?

Here is a simple example to illustrate this issue:

import jax
import jax.numpy as jnp
import numpy as np
import time

from jax.config import config
config.update("jax_enable_x64", True) # Remove this line for single precision

# Set up data
a1 = jnp.zeros((1000, 256), dtype=np.complex128).block_until_ready()
a2 = jax.random.normal(jax.random.PRNGKey(0), (1000, 256), dtype=jnp.float64) + 1.j * jax.random.normal(jax.random.PRNGKey(0), (1000, 256), dtype=jnp.float64)
idx = jnp.where(jnp.abs(a2) > 0)
a3 = a2[idx]

# Check data types used
print(a1.dtype)
print(a2.dtype)

# define jit'd index_update function
index_update_jitd = jax.jit(jax.ops.index_update)

# timing including jit
tic = time.perf_counter()
a1 = index_update_jitd(a1, idx, a3).block_until_ready()
print("Execution time including jit: ", time.perf_counter() - tic)

# reset data
a1 = jnp.zeros((1000, 256), dtype=np.complex128).block_until_ready()

# timing of jit'd function
tic = time.perf_counter()
a1 = index_update_jitd(a1, idx, a3).block_until_ready()
print("Execution time of jit'd function: ", time.perf_counter() - tic)

I am running this on an NVIDIA V100. Using single precision I get the output

complex64
complex64
Execution time including jit: 0.1222537359863054
Execution time of jit'd function: 0.0004156339855398983

With double precision it is

complex128
complex128
Execution time including jit: 3.707396164012607
Execution time of jit'd function: 3.5751446500071324

As you can see, with double precision the execution is almost 10,000 times slower, which puzzles me.

Also, the execution time seems to scale roughly linearly with the array size for double precision, whereas for single precision it barely changed over the sizes I checked.
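For reference, here is a minimal sketch of how the size dependence can be checked. The helper below and the list of sizes are only illustrative, not the exact runs behind the numbers above:

import time

import jax
import jax.numpy as jnp
import numpy as np

from jax.config import config
config.update("jax_enable_x64", True)

def time_index_update(n_rows):
    # Build data of the requested size, mirroring the example above.
    a = jnp.zeros((n_rows, 256), dtype=np.complex128).block_until_ready()
    key = jax.random.PRNGKey(0)
    vals = jax.random.normal(key, (n_rows, 256), dtype=jnp.float64) + 1.j * jax.random.normal(key, (n_rows, 256), dtype=jnp.float64)
    idx = jnp.where(jnp.abs(vals) > 0)
    f = jax.jit(jax.ops.index_update)
    f(a, idx, vals[idx]).block_until_ready()  # warm-up call, includes compilation
    tic = time.perf_counter()
    f(a, idx, vals[idx]).block_until_ready()  # timed call of the jit'd function
    return time.perf_counter() - tic

for n in (250, 500, 1000, 2000):
    print(n, time_index_update(n))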

jekbradbury commented 4 years ago

Seems like the same issue as https://github.com/google/jax/issues/3270, but for complex numbers. Maybe the fix wasn't general enough?

hawkinsp commented 3 years ago

#4907 added a workaround for the specific case of indexed additions with complex numbers. The non-addition cases are harder to fix and would require XLA to use a strategy other than GPU atomic operations for scatters: complex128 values are too wide for the widest GPU atomic instructions.
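One possible user-side workaround, sketched here under the assumption that plain float64 scatters do not run into this atomics limitation (the helper name index_update_complex is just for illustration), is to scatter into the real and imaginary parts separately and recombine:

import jax
import jax.numpy as jnp

from jax.config import config
config.update("jax_enable_x64", True)

@jax.jit
def index_update_complex(x, idx, y):
    # Update the real and imaginary parts as two separate float64 scatters,
    # then recombine them into a complex128 result.
    re = jax.ops.index_update(jnp.real(x), idx, jnp.real(y))
    im = jax.ops.index_update(jnp.imag(x), idx, jnp.imag(y))
    return re + 1.j * im

In the example above this would be used as a1 = index_update_complex(a1, idx, a3); whether it actually recovers single-precision-like performance would need to be checked on the GPU.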