google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BF16 matmul slower than F32 matmul on T4 GPU #21212

Closed: sagelywizard closed this issue 1 week ago

sagelywizard commented 3 months ago

Description

BF16 matmul appears to be slower than F32 matmul on the T4. In my test, BF16 runs at roughly half the speed of F32. I believe this is a bug: bf16 should be at least as fast as f32 (the same speed or possibly faster).

You can reproduce this in a T4 Colab runtime with the following:

import jax
import jax.numpy as jnp
import timeit

def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
  key = jax.random.PRNGKey(0)
  # Benchmark an (x_i, x_j) @ (x_j, y_j) matmul with x_i = 2**exponent rows.
  x_i = 2**exponent
  x_j = 4096
  y_j = 4096
  flop_count = x_i * x_j * y_j * 2
  x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
  y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
  matmul = jax.jit(lambda a, b: a @ b)
  # Warm-up call so that compilation time is excluded from the measurement.
  matmul(x, y).block_until_ready()
  # JAX dispatches asynchronously; block_until_ready() makes the timer wait for the GPU.
  seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
  flops = flop_count / seconds_per_iter
  return flop_count, flops

def flops_to_tflops(flops):
  return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
  print(dtype)
  for i in range(16):
    op_count, flops = flops_calc(exponent=i, dtype=dtype)
    print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
  print()
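The `flop_count` line uses the usual count for a dense matmul: multiplying an (M, K) matrix by a (K, N) matrix takes M·K·N multiply-adds, i.e. 2·M·K·N floating-point operations. A small sketch that reproduces the `Total TFLOP Count` column of the output (the helper name `matmul_flops` is hypothetical, not part of JAX):

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m, k) @ (k, n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


# Same shapes as the benchmark loop: x is (2**i, 4096), y is (4096, 4096).
for i in range(16):
    tflop = matmul_flops(2**i, 4096, 4096) / 1e12
    print(f"exponent={i:2d} | Total TFLOP Count: {tflop:.5f}")
```

At the largest size each matmul is about 1.1 TFLOP, so a throughput of roughly 2 TFLOPS means about half a second per iteration.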

This results in the following output:

<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.04
Total TFLOP Count: 0.00013 | TFLOPS: 0.09
Total TFLOP Count: 0.00027 | TFLOPS: 0.16
Total TFLOP Count: 0.00054 | TFLOPS: 0.35
Total TFLOP Count: 0.00107 | TFLOPS: 0.61
Total TFLOP Count: 0.00215 | TFLOPS: 1.09
Total TFLOP Count: 0.00429 | TFLOPS: 1.22
Total TFLOP Count: 0.00859 | TFLOPS: 1.74
Total TFLOP Count: 0.01718 | TFLOPS: 2.27
Total TFLOP Count: 0.03436 | TFLOPS: 2.36
Total TFLOP Count: 0.06872 | TFLOPS: 2.36
Total TFLOP Count: 0.13744 | TFLOPS: 2.16
Total TFLOP Count: 0.27488 | TFLOPS: 2.19
Total TFLOP Count: 0.54976 | TFLOPS: 2.14
Total TFLOP Count: 1.09951 | TFLOPS: 2.09

<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.92
Total TFLOP Count: 0.00054 | TFLOPS: 1.76
Total TFLOP Count: 0.00107 | TFLOPS: 3.53
Total TFLOP Count: 0.00215 | TFLOPS: 6.99
Total TFLOP Count: 0.00429 | TFLOPS: 14.04
Total TFLOP Count: 0.00859 | TFLOPS: 23.47
Total TFLOP Count: 0.01718 | TFLOPS: 25.02
Total TFLOP Count: 0.03436 | TFLOPS: 35.24
Total TFLOP Count: 0.06872 | TFLOPS: 37.16
Total TFLOP Count: 0.13744 | TFLOPS: 31.20
Total TFLOP Count: 0.27488 | TFLOPS: 24.41
Total TFLOP Count: 0.54976 | TFLOPS: 23.02
Total TFLOP Count: 1.09951 | TFLOPS: 22.13

<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.16
Total TFLOP Count: 0.00013 | TFLOPS: 0.31
Total TFLOP Count: 0.00027 | TFLOPS: 0.66
Total TFLOP Count: 0.00054 | TFLOPS: 1.34
Total TFLOP Count: 0.00107 | TFLOPS: 2.61
Total TFLOP Count: 0.00215 | TFLOPS: 4.18
Total TFLOP Count: 0.00429 | TFLOPS: 4.92
Total TFLOP Count: 0.00859 | TFLOPS: 5.32
Total TFLOP Count: 0.01718 | TFLOPS: 4.59
Total TFLOP Count: 0.03436 | TFLOPS: 4.31
Total TFLOP Count: 0.06872 | TFLOPS: 4.19
Total TFLOP Count: 0.13744 | TFLOPS: 4.04
Total TFLOP Count: 0.27488 | TFLOPS: 4.30
Total TFLOP Count: 0.54976 | TFLOPS: 4.31
Total TFLOP Count: 1.09951 | TFLOPS: 4.37

Note how much slower bf16 is than f32: at the largest size it reaches about 2.1 TFLOPS versus about 4.4 TFLOPS for f32, roughly half. (Side note: bf16 is also far slower than f16, but my understanding is that this is because the T4 doesn't support bf16, so JAX alters the computation to use f32.)
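If the T4 computes bf16 matmuls via f32, one way to check whether the slowdown is in that fallback path is to do the upcast by hand and time it instead: cast the bf16 inputs to float32, multiply, and cast back. This is a minimal sketch, not a description of what XLA does internally; the function name `matmul_via_f32` is hypothetical, and whether the manual version actually runs at f32 speed on a T4 is an expectation to verify, not a measured result.

```python
import jax
import jax.numpy as jnp


@jax.jit
def matmul_via_f32(a, b):
    # Upcast bf16 inputs to f32, multiply in f32, then round the result back to bf16.
    out = jnp.matmul(a.astype(jnp.float32), b.astype(jnp.float32))
    return out.astype(jnp.bfloat16)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (256, 512), dtype=jnp.bfloat16)
b = jax.random.uniform(key_b, (512, 128), dtype=jnp.bfloat16)

native = jax.jit(lambda x, y: x @ y)(a, b)  # whatever path XLA picks for bf16
upcast = matmul_via_f32(a, b)               # explicit f32 computation
print(native.dtype, upcast.dtype)
print(jnp.max(jnp.abs(native.astype(jnp.float32) - upcast.astype(jnp.float32))))
```

Timing `matmul_via_f32` with the same `timeit` loop as the repro would show whether an explicit upcast avoids the slowdown.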

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='063d876e5268', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon May 13 18:04:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0              30W /  70W |  11457MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
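The version block and `nvidia-smi` table above look like the output of `jax.print_environment_info()`, which prints the jax/jaxlib/numpy/python versions and visible devices, and appends `nvidia-smi` output when that command is available. A short sketch for collecting it along with the accelerator model, assuming a JAX version whose `print_environment_info` accepts `return_string=True`:

```python
import jax

# Same summary as in the report; return_string=True returns it instead of printing.
info = jax.print_environment_info(return_string=True)
print(info)

# device_kind names the accelerator model (on a T4 Colab runtime, "Tesla T4").
for d in jax.devices():
    print(d.platform, d.device_kind)
```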
jakevdp commented 3 months ago

Thanks for the report!

Here's a ref to the T4 architecture spec: https://images.nvidia.com/aem-dam/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf

T4 doesn't support bfloat16, but JAX (via the XLA GPU compiler) should be falling back to float32. The fact that the result is appreciably slower than native float32 may indicate a bug in the XLA GPU compiler.
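To see what the XLA GPU compiler actually emits for a bf16 matmul, one can inspect the optimized HLO through JAX's ahead-of-time lowering API (`jit(...).lower(...).compile().as_text()`). What shows up (for example, `convert` ops to f32 around the dot, or a library custom call) depends on the backend and XLA version, so this sketch is a diagnostic aid, not a guaranteed pattern:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((256, 512), dtype=jnp.bfloat16)
y = jnp.ones((512, 512), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)
lowered = matmul.lower(x, y)     # StableHLO before XLA optimizations
compiled = lowered.compile()     # backend-specific compilation
hlo_text = compiled.as_text()    # optimized HLO after XLA passes

# Show only lines mentioning dots, dtype conversions, or custom calls.
for line in hlo_text.splitlines():
    if any(tok in line for tok in ("dot(", "convert(", "custom-call(")):
        print(line.strip())
```

Running this on a T4 and on a GPU with native bf16 support, then diffing the two texts, is one way to localize the slow path.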

I'd suggest reporting this at https://github.com/openxla/xla
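An XLA report is easier to act on with HLO dumps attached. A sketch using the `--xla_dump_to` flag passed through the `XLA_FLAGS` environment variable; it assumes a fresh Python process, because the variable is read when JAX initializes its backend, and it overwrites any `XLA_FLAGS` already set. The temporary directory is an arbitrary choice:

```python
import os
import tempfile

# Set XLA_FLAGS before importing jax so the backend picks it up.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp

x = jnp.ones((256, 512), dtype=jnp.bfloat16)
y = jnp.ones((512, 512), dtype=jnp.bfloat16)
jax.jit(lambda a, b: a @ b)(x, y).block_until_ready()

# XLA writes text dumps of each compiled module (e.g. before and after optimization).
for name in sorted(os.listdir(dump_dir)):
    print(name)
```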

cheshire commented 3 months ago

Replied on the OpenXLA bug.

rajasekharporeddy commented 2 months ago

Hi @sagelywizard

I tested the code above on a Google Colab T4 GPU with JAX versions 0.4.26, 0.4.27, and later. Starting with JAX 0.4.27, BF16 matmul runs at roughly the same speed as F32 matmul on the T4 GPU. Here is the output I get with JAX 0.4.27:

import jax
import jax.numpy as jnp
import timeit

print(jax.__version__)

def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
  key = jax.random.PRNGKey(0)
  x_i = 2**exponent
  x_j = 4096
  y_j = 4096
  flop_count = x_i * x_j * y_j * 2
  x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
  y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
  matmul = jax.jit(lambda a, b: a @ b)
  matmul(x, y).block_until_ready()
  seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
  flops = flop_count / seconds_per_iter
  return flop_count, flops

def flops_to_tflops(flops):
  return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
  print(dtype)
  for i in range(16):
    op_count, flops = flops_calc(exponent=i, dtype=dtype)
    print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
  print()

Output:

0.4.27
<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.06
Total TFLOP Count: 0.00013 | TFLOPS: 0.13
Total TFLOP Count: 0.00027 | TFLOPS: 0.23
Total TFLOP Count: 0.00054 | TFLOPS: 0.46
Total TFLOP Count: 0.00107 | TFLOPS: 0.88
Total TFLOP Count: 0.00215 | TFLOPS: 1.47
Total TFLOP Count: 0.00429 | TFLOPS: 1.95
Total TFLOP Count: 0.00859 | TFLOPS: 4.33
Total TFLOP Count: 0.01718 | TFLOPS: 5.20
Total TFLOP Count: 0.03436 | TFLOPS: 5.17
Total TFLOP Count: 0.06872 | TFLOPS: 4.74
Total TFLOP Count: 0.13744 | TFLOPS: 4.69
Total TFLOP Count: 0.27488 | TFLOPS: 4.98
Total TFLOP Count: 0.54976 | TFLOPS: 4.96
Total TFLOP Count: 1.09951 | TFLOPS: 4.89

<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.77
Total TFLOP Count: 0.00054 | TFLOPS: 1.66
Total TFLOP Count: 0.00107 | TFLOPS: 3.32
Total TFLOP Count: 0.00215 | TFLOPS: 6.68
Total TFLOP Count: 0.00429 | TFLOPS: 10.77
Total TFLOP Count: 0.00859 | TFLOPS: 14.33
Total TFLOP Count: 0.01718 | TFLOPS: 22.38
Total TFLOP Count: 0.03436 | TFLOPS: 29.81
Total TFLOP Count: 0.06872 | TFLOPS: 31.90
Total TFLOP Count: 0.13744 | TFLOPS: 27.16
Total TFLOP Count: 0.27488 | TFLOPS: 24.00
Total TFLOP Count: 0.54976 | TFLOPS: 23.06
Total TFLOP Count: 1.09951 | TFLOPS: 24.39

<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.15
Total TFLOP Count: 0.00013 | TFLOPS: 0.32
Total TFLOP Count: 0.00027 | TFLOPS: 0.59
Total TFLOP Count: 0.00054 | TFLOPS: 1.14
Total TFLOP Count: 0.00107 | TFLOPS: 1.53
Total TFLOP Count: 0.00215 | TFLOPS: 2.22
Total TFLOP Count: 0.00429 | TFLOPS: 2.44
Total TFLOP Count: 0.00859 | TFLOPS: 2.64
Total TFLOP Count: 0.01718 | TFLOPS: 2.74
Total TFLOP Count: 0.03436 | TFLOPS: 2.94
Total TFLOP Count: 0.06872 | TFLOPS: 3.88
Total TFLOP Count: 0.13744 | TFLOPS: 4.17
Total TFLOP Count: 0.27488 | TFLOPS: 4.54
Total TFLOP Count: 0.54976 | TFLOPS: 4.46
Total TFLOP Count: 1.09951 | TFLOPS: 4.51
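Comparing the largest-size rows (about 1.1 TFLOP per matmul) of the two runs makes the change concrete. This sketch only redoes the arithmetic on the numbers posted in this thread:

```python
# Throughput (TFLOPS) at the largest size, copied from the outputs in this thread.
runs = {
    "0.4.26": {"bf16": 2.09, "f32": 4.37},
    "0.4.27": {"bf16": 4.89, "f32": 4.51},
}

for version, tflops in runs.items():
    ratio = tflops["bf16"] / tflops["f32"]
    print(f"jax {version}: bf16/f32 throughput ratio = {ratio:.2f}")
# jax 0.4.26: bf16/f32 throughput ratio = 0.48
# jax 0.4.27: bf16/f32 throughput ratio = 1.08
```

Both are single runs from separate Colab sessions, so the ratios are indicative rather than precise.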

Please see the gist for reference.

Thank you.

sagelywizard commented 1 week ago

Perfect, thanks! I was able to reproduce the improvement post-0.4.27. I think we can close this out.