google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX segment_sum is two times slower for FP16 inputs than FP32 inputs #23136

Open CloudyDory opened 3 weeks ago

CloudyDory commented 3 weeks ago

Description

I found that JAX's segment_sum is more than two times slower for FP16 inputs than for FP32 inputs. Here is an example:

import time
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
start_time = time.time()
data_sum = jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)
data_sum.block_until_ready()
print('Run time for FP32: {:.5f} seconds.'.format(time.time()-start_time))

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
start_time = time.time()
data_sum = jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)
data_sum.block_until_ready()
print('Run time for FP16: {:.5f} seconds.'.format(time.time()-start_time))

Outputs:

Run time for FP32: 0.03310 seconds.
Run time for FP16: 0.08621 seconds.

This happens with or without jit(). Why does this happen? And is there a way to optimize the computation for FP16 input?
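One experiment worth trying is to upcast the FP16 data to FP32 before the scatter-add and cast the result back. This is an untested sketch, not a confirmed fix. It trades one extra conversion pass for an FP32 scatter, which the timings above suggest is the fast path on this GPU. It also accumulates in higher precision, which matters when an FP16 segment holds hundreds of elements. The helper name `segment_sum_fp32_accum` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def segment_sum_fp32_accum(data, segment_ids, num_segments):
    # Hypothetical helper: run the scatter-add in float32, return in the input dtype.
    acc = jax.ops.segment_sum(
        data.astype(jnp.float32),
        segment_ids=segment_ids,
        num_segments=num_segments,
        indices_are_sorted=True,
    )
    return acc.astype(data.dtype)


# num_segments sets the output shape, so it must be static under jit.
fast_sum = jax.jit(segment_sum_fp32_accum, static_argnames="num_segments")

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments),
                        np.random.randint(40, 977, size=num_segments))
key = jax.random.PRNGKey(0)
data_fp16 = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)

result = fast_sum(data_fp16, segment_ids, num_segments=num_segments)
result.block_until_ready()
print(result.dtype, result.shape)
```

Whether this is faster than the plain FP16 call on an RTX 4090 would need to be measured with a benchmark that excludes compilation time.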

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ZJ', release='6.5.0-45-generic', version='#45~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Jul 15 16:40:02 UTC 2', machine='x86_64')

$ nvidia-smi
Tue Aug 20 10:40:04 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
|  0%   49C    P8              39W / 480W |     33MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:03:00.0 Off |                  Off |
|  0%   45C    P8              36W / 480W |   5032MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1284      G   /usr/lib/xorg/Xorg                            9MiB |
|    0   N/A  N/A      1317      G   /usr/bin/gnome-shell                         10MiB |
|    1   N/A  N/A      1284      G   /usr/lib/xorg/Xorg                            4MiB |
|    1   N/A  N/A     88489      C   .../miniconda3/envs/jax/bin/python         5012MiB |
+---------------------------------------------------------------------------------------+
dfm commented 3 weeks ago

I don't know the answer to this, but maybe @jakevdp does?

Some notes in the meantime: it's worth checking out the JAX microbenchmark FAQ entry. Timing code the way you do here includes tracing and compilation overhead, which can lead to incorrect conclusions. Fixing that doesn't seem to change the conclusion in this case, though! Here's how I would write the benchmark:

```python
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

@jax.jit
def do_sum(data):
  return jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
do_sum(data).block_until_ready() # compile
%timeit do_sum(data).block_until_ready()

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
do_sum(data).block_until_ready() # compile
%timeit do_sum(data).block_until_ready()
```
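The `%timeit` lines in the benchmark are IPython magics. To run the same measurement as a plain Python script, the standard-library `timeit` module can stand in. This is a sketch that assumes the same data setup. `bench` is a hypothetical helper, and the `number`/`repeat` values are arbitrary choices.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments),
                        np.random.randint(40, 977, size=num_segments))


@jax.jit
def do_sum(data):
    return jax.ops.segment_sum(data, segment_ids=segment_ids,
                               num_segments=num_segments, indices_are_sorted=True)


def bench(dtype, number=10, repeat=3):
    # Hypothetical helper: best mean seconds per call for one input dtype.
    key = jax.random.PRNGKey(0)
    data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=dtype)
    do_sum(data).block_until_ready()  # warm-up: trace and compile for this dtype
    # block_until_ready() makes each timed call wait for the device work,
    # not just for the asynchronous dispatch.
    times = timeit.repeat(lambda: do_sum(data).block_until_ready(),
                          number=number, repeat=repeat)
    return min(times) / number


results = {name: bench(dt) for name, dt in
           [("float32", jnp.float32), ("float16", jnp.float16)]}
for name, seconds in results.items():
    print(f"{name}: {seconds * 1e6:.1f} us per call")
```

Taking the minimum over repeats is a common way to reduce noise from other processes. The warm-up call matters because changing the dtype triggers a fresh compilation.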

Regardless, I do find that the float16 version is consistently slower. Perhaps @jakevdp can lead us in the right direction!

jakevdp commented 3 weeks ago

Interesting question! I suspect the reason for the performance difference here is that the GPU hardware is designed and tuned for float32 computation, and not for float16 computation. It would be interesting to compare this across different generations of GPU hardware.

CloudyDory commented 3 weeks ago


But I don't think a GPU's FP16 performance should be lower than its FP32 performance. For example, the A100's FP16 FLOPS are twice its FP32 FLOPS, and for the RTX 4090, some data shows equal FP16 and FP32 performance.

Is it possible that JAX somehow internally converts FP16 to FP32, performs the computation, and converts the result back to FP16?

jakevdp commented 3 weeks ago

No, I don't think such conversions are happening. You can see exactly which operations the compiler emits by using ahead-of-time lowering to print the compiled HLO. This is the output on a T4 GPU:

key = jax.random.key(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype='float16')
print(jax.jit(lambda data: jax.ops.segment_sum(
        data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
      )).lower(data).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f16[852234]{0})->f16[1700]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="31c88d085a583f79a5c8e16aa07882b5"}

%region_0.7 (Arg_0.8.0: f16[], Arg_1.9.0: f16[]) -> f16[] {
  %Arg_1.9.0 = f16[] parameter(1)
  %Arg_0.8.0 = f16[] parameter(0)
  ROOT %add.1.0 = f16[] add(f16[] %Arg_0.8.0, f16[] %Arg_1.9.0), metadata={op_name="/add" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

%fused_scatter (param_0: f16[1700], param_1: s32[852234,1], param_2.1: f16[852234]) -> f16[1700] {
  %param_0 = f16[1700]{0} parameter(0)
  %param_1 = s32[852234,1]{1,0} parameter(1)
  %param_2.1 = f16[852234]{0} parameter(2)
  %bitcast.26.1 = f16[852234,1]{1,0} bitcast(f16[852234]{0} %param_2.1)
  ROOT %scatter.11.1 = f16[1700]{0} scatter(f16[1700]{0} %param_0, s32[852234,1]{1,0} %param_1, f16[852234,1]{1,0} %bitcast.26.1), update_window_dims={1}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_0.7, metadata={op_name="jit(<lambda>)/jit(main)/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

%fused_broadcast () -> f16[1700] {
  %constant_3_1 = f16[] constant(0)
  ROOT %broadcast.1.1 = f16[1700]{0} broadcast(f16[] %constant_3_1), dimensions={}, metadata={op_name="jit(<lambda>)/jit(main)/broadcast_in_dim[shape=(1700,) broadcast_dimensions=()]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

ENTRY %main.12 (Arg_0.1.0: f16[852234]) -> f16[1700] {
  %constant_1_0 = s32[852234,1]{1,0} constant({...})
  %Arg_0.1.0 = f16[852234]{0} parameter(0)
  %loop_broadcast_fusion = f16[1700]{0} fusion(), kind=kLoop, calls=%fused_broadcast, metadata={op_name="jit(<lambda>)/jit(main)/broadcast_in_dim[shape=(1700,) broadcast_dimensions=()]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
  ROOT %input_scatter_fusion = f16[1700]{0} fusion(f16[1700]{0} %loop_broadcast_fusion, s32[852234,1]{1,0} %constant_1_0, f16[852234]{0} %Arg_0.1.0), kind=kInput, calls=%fused_scatter, metadata={op_name="jit(<lambda>)/jit(main)/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}
CloudyDory commented 3 weeks ago

Thanks for the clarification!

What might be the problem then? I'm curious how we could debug this issue.
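One rough starting point is to scan the HLO for `convert` ops at two stages: the StableHLO that JAX emits, and the optimized HLO that XLA compiles for the current backend. This is a hypothetical helper, sketched here; `find_converts` is not a JAX API, and plain substring matching is a heuristic rather than a parser. The compiled output is backend-specific, so it should be run on the GPU in question. A CPU build may insert conversions that the GPU build does not.

```python
import numpy as np
import jax
import jax.numpy as jnp


def find_converts(fn, *args, compiled=True):
    """Return the HLO lines of jit(fn)(*args) that contain a convert op.

    compiled=True inspects the optimized HLO for the current backend;
    compiled=False inspects the StableHLO JAX emits before XLA optimizes it.
    """
    lowered = jax.jit(fn).lower(*args)
    if compiled:
        text = lowered.compile().as_text() or ""
        marker = " convert("           # HLO text form: ... = f32[..] convert(f16[..] ...)
    else:
        text = lowered.as_text()
        marker = "stablehlo.convert"   # StableHLO op name
    return [line.strip() for line in text.splitlines() if marker in line]


# Usage on the float16 segment_sum from this issue.
num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments),
                        np.random.randint(40, 977, size=num_segments))
data = jax.random.uniform(jax.random.PRNGKey(0), shape=(len(segment_ids),),
                          dtype=jnp.float16)


def seg_sum(d):
    return jax.ops.segment_sum(d, segment_ids=segment_ids,
                               num_segments=num_segments, indices_are_sorted=True)


for line in find_converts(seg_sum, data):
    print(line)
```

If no conversions show up, the next step is usually to look at per-kernel timings, for example with a trace captured via `jax.profiler.trace`.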

jakevdp commented 3 weeks ago

My best guess is still that the hardware you're using is not optimized for the kind of operation you're performing (i.e. scatters) in float16, and is better optimized for float32.

Appendix A of https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf suggests that on the GeForce RTX 4090, non-tensor ops are no faster in F16 than in F32, though it doesn't indicate that they should be slower. It may be that performance is worse for F16 scatters specifically; I'm not sure.
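If F16 scatters are the slow path, one option is a prefix-sum formulation that avoids the scatter entirely when the ids are already sorted (as `indices_are_sorted=True` asserts here). This is a different technique from what `jax.ops.segment_sum` does. The sketch below assumes the ids are sorted ascending and lie in `[0, num_segments)`. `sorted_segment_sum_cumsum` is a hypothetical name, and whether it is actually faster on an RTX 4090 is untested.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_segments")
def sorted_segment_sum_cumsum(data, segment_ids, num_segments):
    # Assumes segment_ids is sorted ascending and within [0, num_segments).
    # Running totals are kept in float32: a float16 prefix sum over ~850k
    # values in [0, 1) would overflow float16's maximum of 65504.
    csum = jnp.cumsum(data.astype(jnp.float32))
    csum = jnp.concatenate([jnp.zeros((1,), jnp.float32), csum])  # csum[k] = sum(data[:k])
    seg = jnp.arange(num_segments)
    # Each segment occupies positions [start, end) of the sorted id array;
    # an empty segment has start == end and therefore sums to 0.
    starts = jnp.searchsorted(segment_ids, seg, side="left")
    ends = jnp.searchsorted(segment_ids, seg, side="right")
    return (csum[ends] - csum[starts]).astype(data.dtype)


ids = np.array([0, 0, 1, 3, 3, 3])
vals = jnp.array([1., 2., 3., 4., 5., 6.], dtype=jnp.float16)
print(sorted_segment_sum_cumsum(vals, ids, num_segments=5))
```

The trade-off is precision. Each segment sum is a difference of two large float32 prefix sums, so absolute error grows with the running total. Results can therefore differ slightly from a direct scatter-add on very long inputs.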

CloudyDory commented 3 weeks ago

I wrote a similar benchmark using PyTorch 2.3.1 and the torch_scatter library, and I now agree that non-tensor ops are no faster in F16 than in F32 on the GeForce RTX 4090.

However, PyTorch's FP16 performance appears to be about 380 times faster than JAX's FP16 performance on the RTX 4090. If the following benchmark code is correct, isn't there still a lot of room for improvement in JAX?
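The 380x figure can be checked against the timings reported further down in this comment. PyTorch measured 102 μs (FP32) and 120 μs (FP16); JAX measured 161 μs (FP32) and 45.8 ms (FP16). The arithmetic also shows that the gap is specific to FP16.

```python
# Mean time per call reported in this comment, in seconds (RTX 4090).
torch_fp32, torch_fp16 = 102e-6, 120e-6
jax_fp32, jax_fp16 = 161e-6, 45.8e-3

ratios = {
    "JAX FP16 / PyTorch FP16": jax_fp16 / torch_fp16,
    "JAX FP16 / JAX FP32": jax_fp16 / jax_fp32,
    "PyTorch FP16 / PyTorch FP32": torch_fp16 / torch_fp32,
    "JAX FP32 / PyTorch FP32": jax_fp32 / torch_fp32,
}
for name, r in ratios.items():
    print(f"{name}: {r:.2f}x")
```

In FP32 the two libraries are within about 1.6x of each other, and torch_scatter's FP16 path is only about 1.2x slower than its FP32 path. JAX's FP16 path, by contrast, is roughly 285x slower than its own FP32 path.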

PyTorch code:

import torch
import torch_scatter  # !conda install pytorch-scatter -c pyg

def do_sum(data):
    y = torch_scatter.scatter(data, segment_ids, reduce='sum')
    torch.cuda.synchronize()
    return y

device = torch.device('cuda')
num_segments = 1700
segment_ids = torch.repeat_interleave(torch.arange(num_segments, device=device), torch.randint(40,977,size=(num_segments,),device=device))

# FP32
data_fp32 = torch.rand(len(segment_ids), dtype=torch.float32, device=device)
torch.cuda.synchronize()
%timeit do_sum(data_fp32)

# FP16
data_fp16 = torch.rand(len(segment_ids), dtype=torch.float16, device=device)
torch.cuda.synchronize()
%timeit do_sum(data_fp16)

PyTorch result:

102 μs ± 75.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
120 μs ± 121 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Jax code:

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

@jax.jit
def do_sum(data):
  return jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)

key = jax.random.PRNGKey(0)
data_fp32 = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
do_sum(data_fp32).block_until_ready()  # compile
%timeit do_sum(data_fp32).block_until_ready()

key = jax.random.PRNGKey(0)
data_fp16 = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
do_sum(data_fp16).block_until_ready()  # compile
%timeit do_sum(data_fp16).block_until_ready()

Jax result:

161 μs ± 2.84 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
45.8 ms ± 162 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
jakevdp commented 3 weeks ago

I don't think these are equivalent operations. Wouldn't torch.scatter be equivalent to JAX's scatter, not to JAX's segment_sum?
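In JAX the distinction is narrower than it looks. segment_sum is itself a scatter-add into a zero-initialized output: the compiled HLO earlier in the thread shows a `scatter` whose `to_apply` region is an `add`, tagged `scatter-add`. The check below uses the same toy data that appears later in the thread. It compares `jax.ops.segment_sum` against the array `.at[...].add(...)` form of scatter-add, with NumPy's `np.add.at` as an independent reference.

```python
import numpy as np
import jax
import jax.numpy as jnp

data = jnp.array([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=jnp.float32)
segment_ids = jnp.array([1, 0, 0, 1, 2, 2, 3], dtype=jnp.int32)
num_segments = 4

# segment_sum as exposed by JAX
seg = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)
# The same reduction written as an explicit scatter-add into zeros
scat = jnp.zeros(num_segments, data.dtype).at[segment_ids].add(data)
# NumPy reference: np.add.at accumulates repeated indices (unbuffered)
ref = np.zeros(num_segments, np.float32)
np.add.at(ref, np.asarray(segment_ids), np.asarray(data))

print(seg, scat, ref)
```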

CloudyDory commented 3 weeks ago

Hi, this is not torch.scatter, but torch_scatter.scatter in the torch_scatter library (https://github.com/rusty1s/pytorch_scatter).

According to the documentation (https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter), torch_scatter.scatter is actually doing segment_sum.

We can also verify this with the following code.

PyTorch:

import torch
import torch_scatter

data = torch.tensor([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=torch.float32)
segment_ids = torch.tensor([1,0,0,1,2,2,3], dtype=torch.int64)
print(torch_scatter.scatter(data, segment_ids, reduce='sum'))

Output:

tensor([9.9800, 1.8200, 1.6300, 1.2000])

JAX:

import jax
import jax.numpy as jnp

data = jnp.array([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=jnp.float32)
segment_ids = jnp.array([1,0,0,1,2,2,3], dtype=jnp.int32)
print(jax.ops.segment_sum(data, segment_ids=segment_ids))

Output:

[9.98      1.8199999 1.6299999 1.2      ]
jakevdp commented 3 weeks ago

Ah, thanks for the clarification. It looks like it is doing the same thing; I'm not sure why JAX's version is slower.