jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax on GPU and jit is 11x slower than Pytorch on GPU #13151

Closed kyrollosyanny closed 10 months ago

kyrollosyanny commented 1 year ago

Description

I migrated from PyTorch to JAX, but I am noticing an 11x slowdown in JAX. To test more generally, I used a simple function that sums the first three powers of a matrix:

import numpy as np
import jax.numpy as jnp
from jax import jit

def fn(x):
    return x + x*x + x*x*x

x = np.random.randn(10000, 10000).astype(dtype='float32')
jax_fn = jit(fn)
x = jnp.array(x)
%timeit -n5 jax_fn(x).block_until_ready()

JAX takes 5.48 ms. This is running on the GPU (verified by checking print(device_put(1, jax.devices()[0]).device_buffer.device())).

The same code in PyTorch on the same GPU runs in 459 microseconds, which is roughly 11x faster.

I'm wondering where the slowdown is coming from and whether there are any ways to speed it up.

Thanks a lot for your help

What jax/jaxlib version are you using?

pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10.6, running under WSL.

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.03    Driver Version: 516.25       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro P5200        On   | 00000000:01:00.0  On |                  N/A |
| N/A   39C    P8     7W /  N/A |  14497MiB / 16384MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A        28      G   /Xwayland                       N/A      |
|    0   N/A  N/A      4039      C   /python3.10                     N/A      |
|    0   N/A  N/A      4389      C   /python3.10                     N/A      |
+-----------------------------------------------------------------------------+
nicholasjng commented 1 year ago

I think your benchmark also measures the compilation time of jax_fn, which happens in the first iteration of timeit (that is what the "JIT" in jax.jit refers to, as far as I'm aware: the function is compiled lazily, on its first invocation with concrete arguments).

Meanwhile, the JAX docs (FAQ section) have this snippet, which splits compilation off from the benchmark proper:

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

So by restructuring your benchmark along these lines, you should be able to obtain significantly better runtimes.

kyrollosyanny commented 1 year ago

Hi @nicholasjng,

Thanks for your response. I followed your suggestion, but it did not change the timings significantly; the outputs are below. It is still much slower than PyTorch (5.41 ms for JAX vs. 0.459 ms for PyTorch). Could you replicate this simple example and report your timings? I'm wondering whether there is something inherently slower about JAX here, or whether it's something in my version or installation.

%time x_jax = jax.device_put(x,device=gpus[0])  # measure JAX device transfer time
CPU times: user 47.9 ms, sys: 293 ms, total: 341 ms
Wall time: 623 ms

f_jit = jax.jit(fn)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
CPU times: user 36.6 ms, sys: 6.71 ms, total: 43.3 ms
Wall time: 207 ms
DeviceArray([[ 0.73889023, -0.00520617, -1.6567433 , ...,  0.25744492,  1.8352127 , -0.19825274],
             [ 0.12012546, -0.6986189 , -0.21019334, ..., -0.35663012, -3.3269477 ,  2.5003362 ],
             [ 0.72691584, -0.22199397, 23.771     , ..., -2.9344034 ,  2.3571334 , -0.25715792],
             ...,
             [ 0.20999207,  0.14257683,  0.6876698 , ...,  0.23829335,  8.207063  ,  3.1074247 ],
             [-0.10749066, -0.7747447 ,  6.4594707 , ...,  1.9226366 , -1.1958256 , -0.11689798],
             [ 0.07622809, -9.640307  ,  0.6005621 , ...,  0.65201366, -0.70378876, -0.78431636]], dtype=float32)

%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime
5.41 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nicholasjng commented 1 year ago

Can you share your PyTorch benchmark code?

Interestingly, printing the resulting jaxpr here,

>>> x=np.random.randn(100,100).astype(dtype='float32')
>>> x = jnp.array(x)
>>> jax.make_jaxpr(fn)(x)
{ lambda ; a:f32[100,100]. let
    b:f32[100,100] = mul a a
    c:f32[100,100] = add a b
    d:f32[100,100] = mul a a
    e:f32[100,100] = mul d a
    f:f32[100,100] = add c e
  in (f,) }

reveals that the same calculation (x*x) appears twice. I don't know whether the compiler is smart enough to reuse the result of the first computation for d here (I can imagine it is, but in my view the jaxpr itself should read d:f32[100,100] = mul b a, followed by the final add c d). If it is not, you can still work around it by binding the result of x*x to an intermediate value, although that's a bit verbose; a sketch of this workaround follows below.

(I would still think JAX should be able to do this on its own, though, so maybe this is a more interesting case than just raw compute performance.)
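For reference, the workaround mentioned above is a small rewrite. A sketch (my own, using the same fn as in this issue) that binds x*x once so the duplicate mul never appears in the jaxpr:

def fn_bound(x):
    x2 = x * x                # compute x*x once and reuse it
    return x + x2 + x2 * x

As the next comment points out, XLA deduplicates the repeated multiply anyway, so this rewrite only changes the jaxpr, not the compiled code.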

jakevdp commented 1 year ago

A note on repeated computations: jaxprs don't contain any of this logic (they're just intermediate representations of the computations you write out in the Python code); all deduplication is done at the compiler level. You can confirm this by printing the compiled HLO:

print(jax.jit(fn).lower(x).compile().as_text())
HloModule jit_fn, entry_computation_layout={(f32[10]{0})->f32[10]{0}}

%fused_computation (param_0.2: f32[10]) -> f32[10] {
  %param_0.2 = f32[10]{0} parameter(0)
  %multiply.1 = f32[10]{0} multiply(f32[10]{0} %param_0.2, f32[10]{0} %param_0.2), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  %add.1 = f32[10]{0} add(f32[10]{0} %param_0.2, f32[10]{0} %multiply.1), metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  %multiply.0 = f32[10]{0} multiply(f32[10]{0} %multiply.1, f32[10]{0} %param_0.2), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  ROOT %add.0 = f32[10]{0} add(f32[10]{0} %add.1, f32[10]{0} %multiply.0), metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
}

ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] {
  %Arg_0.1 = f32[10]{0} parameter(0)
  ROOT %fusion = f32[10]{0} fusion(f32[10]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
}

The output is a bit hard to read, but you can see here that multiply(param_0, param_0) is only computed once in the compiled version.
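A quick way to see both sides of this (a sketch, assuming the fn and x defined earlier in the thread and that your jaxlib version exposes as_text() on the lowered object): the lowered, pre-optimization module should still contain both multiplies, while the compiled module has them deduplicated.

lowered = jax.jit(fn).lower(x)
print(lowered.as_text())              # pre-optimization IR: two muls, mirroring the jaxpr
print(lowered.compile().as_text())    # XLA-optimized HLO: multiply(param, param) appears only once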

kyrollosyanny commented 1 year ago

Yeah, absolutely. Here is the PyTorch code, @nicholasjng.

import torch
import numpy as np
def fn(x):
    return x+x*x+x*x*x

x=np.random.randn(10000,10000).astype(dtype='float32')
x_torch=torch.tensor(x,device='cuda')
%timeit  fn(x_torch)

An interesting thing I'm noticing is that the performance gap shrinks as the array size decreases. For example, for 1000x1000 (instead of 10,000x10,000), PyTorch takes 97 µs while JAX takes 393 µs, i.e. about 4x slower instead of 11x.
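A minimal sketch (assuming jax, jnp, np, and fn from above) that times the same jitted function at both sizes, to separate the fixed per-call dispatch overhead from the size-dependent compute cost:

import timeit

for n in (1_000, 10_000):
    xn = jnp.array(np.random.randn(n, n).astype('float32'))
    f = jax.jit(fn)
    f(xn).block_until_ready()                                    # compile outside the timed region
    t = timeit.timeit(lambda: f(xn).block_until_ready(), number=100) / 100
    print(f"{n}x{n}: {t * 1e6:.1f} us per call")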

Your insight about x*x is very interesting. For bigger and more complicated models, do you have any advice on how to avoid duplicate computations?

Thanks a lot for the help and fast responses.

kyrollosyanny commented 1 year ago

To provide more context: the model that I am trying to port from PyTorch to JAX runs in 2 seconds in PyTorch and 16 seconds in JAX.

nicholasjng commented 1 year ago

I cannot reproduce this on Apple Silicon CPU, at least:

➜ cat bench.py
import jax
import jax.numpy as jnp
import numpy as np
import torch
import timeit

def fn(x):
    return x+x*x+x*x*x

y=np.random.randn(1000, 1000).astype(dtype='float32')
y_torch=torch.tensor(y)
y_jax=jnp.array(y)

jax_fn = jax.jit(fn)
jax_fn(y_jax).block_until_ready()  # warm-up call so compilation happens outside the timed region

# timeit returns total seconds for number=1000 calls, so t * 1000 is microseconds per call
t = timeit.timeit(lambda: jax_fn(y_jax).block_until_ready(), number=1000)

print(f"JAX: {t * 1000} usec")

tt = timeit.timeit(lambda: fn(y_torch), number=1000)

print(f"Pytorch: {tt * 1000} usec")

Prints:

~ via 🐍 v3.10.8 (jax)
➜ python bench.py
JAX: 161.2808329955442 usec
Pytorch: 1106.6547080044984 usec

I tried this with torch.tensor(..., device="cpu") as well, with no change. So this could be a GPU-specific problem? Can you post the results you get with this script? Then we would have a common frame of reference for comparing the numbers.

kyrollosyanny commented 1 year ago

You are absolutely right! Running this code on the CPU gives much better performance for JAX.

%timeit -n1000 fn(x_torch)  # for pytorch
3.02 ms ± 89.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit -n1000 f_jit(x_jax).block_until_ready()  # measure JAX runtime
242 µs ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

On the CPU, JAX is 12x FASTER!

I'm using WSL with a Quadro P5200 and CUDA 11.7, but I'm unsure which cuDNN version was used for the earlier GPU tests.
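A minimal sketch for pinning the same computation to the CPU in such comparisons (assuming both CPU and GPU backends are available, and the fn and x from above); arrays committed to an explicit device via jax.device_put keep jitted computations on that device:

cpu = jax.devices("cpu")[0]
x_cpu = jax.device_put(x, cpu)        # commit the array to the CPU device
f_jit = jax.jit(fn)
f_jit(x_cpu).block_until_ready()      # warm-up / compile for the CPU backend
%timeit f_jit(x_cpu).block_until_ready()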

yhtang commented 1 year ago

Just as we need block_until_ready() to properly benchmark GPU code in JAX, for PyTorch we need torch.cuda.synchronize().

On a Colab T4 instance:

%timeit -n 100 f_jit(x_jax).block_until_ready()  # measure JAX runtime

3.85 ms ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

x_torch = torch.tensor(x, device='cuda')

def fn_torch(x):
    r = x + x*x + x*x*x
    torch.cuda.synchronize()  # block until the queued GPU kernels have actually finished
    return r

with torch.no_grad():
    %timeit -n 10 fn_torch(x_torch)

21.3 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Hence JAX is indeed faster 😄
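An alternative to calling torch.cuda.synchronize() inside the benchmarked function is CUDA event timing; a sketch, assuming the fn and x_torch defined above:

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    fn(x_torch)
end.record()
torch.cuda.synchronize()                          # wait for all queued kernels to finish
print(start.elapsed_time(end) / 100, "ms per call")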

hawkinsp commented 1 year ago

@kyrollosyanny Since you said your end-to-end model is slower under JAX, I suspect that your microbenchmark is no longer representative of your original workload. Can you construct a more representative benchmark?
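One way to build such a benchmark (a sketch; model_step, params, and batch are hypothetical placeholders, not the reporter's actual code) is to profile a single step of the real model and inspect the trace in TensorBoard:

import jax

with jax.profiler.trace("/tmp/jax-trace"):        # writes a TensorBoard-readable trace
    out = model_step(params, batch)               # hypothetical: one step of the real model
    out.block_until_ready()                       # make sure the GPU work lands inside the trace
# view with: tensorboard --logdir /tmp/jax-trace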

hawkinsp commented 10 months ago

We never got any more information about the actual end-to-end model, and the only benchmark provided turned out to be faster under JAX. There's not much more we can do here without more information. Closing.