google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX numpy random sampling slows things down compared to normal numpy sampling #6476

Closed hudsonchen closed 5 months ago

hudsonchen commented 3 years ago

Hi all,

I am working on a research project on Bayesian neural networks, and I implemented a neural network with stochastic weights by making minor modifications to the standard haiku deep learning modules.

To train the Bayesian neural network using variational inference, I need to draw multiple samples from the weight distributions for every gradient step, but JAX sampling appears to be slowing things down, taking ~3× as long as normal numpy sampling. I have tried jitting both the sampling function and the random-key splitting function, but that doesn't seem to help. I also tried using vmap instead of a loop, but that only resulted in a minor speedup.
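For concreteness, the loop and vmap variants I tried look roughly like this (a simplified sketch with placeholder shapes and a toy sampler, not my actual haiku model):

```python
import jax
import jax.numpy as jnp

n_samples = 8              # weight samples per gradient step (placeholder)
weight_shape = (256, 256)  # placeholder weight shape

def sample_weights(key):
  # Toy stand-in for sampling one set of stochastic weights
  return jax.random.normal(key, weight_shape)

def sample_loop(key):
  # Loop version: split the key, then sample in a Python loop
  keys = jax.random.split(key, n_samples)
  return jnp.stack([sample_weights(k) for k in keys])

# vmap version: map the sampler over the split keys, then jit
sample_vmap = jax.jit(jax.vmap(sample_weights))

key = jax.random.PRNGKey(0)
samples = sample_vmap(jax.random.split(key, n_samples))
print(samples.shape)  # (8, 256, 256)
```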

For further details on the comparison between numpy and JAX sampling (where I fix all other settings and change only the sampling method), please see the notebook below: JAX numpy vs Normal numpy

I would greatly appreciate any help or pointers for why JAX sampling seems to be so slow.

Thanks a lot!

selamw1 commented 5 months ago

Hi @hudsonchen, unfortunately the notebook you shared, titled JAX numpy vs Normal numpy, is no longer available. The sample code snippet below demonstrates that JAX random sampling is faster than standard NumPy sampling.

Please note that the first time you run a JAX operation, the timing includes compilation, which might affect the comparison; however, I assume this was not the case for you. Another potential issue is passing a shape that is not a tuple. For more details, please refer to the closed issue #1038. The JAX FAQ is also helpful for speed comparisons: Is JAX faster than NumPy?

import jax
import jax.numpy as jnp
import numpy as np

def sample_jax(key, shape):
  # Draw standard-normal samples using JAX's explicit PRNG key
  return jax.random.normal(key, shape)

def sample_numpy(shape):
  # Draw standard-normal samples using NumPy's global RNG
  return np.random.normal(size=shape)

key = jax.random.PRNGKey(0)
shape = (100000,)

print("JAX: ")
%timeit sample_jax(key, shape)

print("NumPy")
%timeit sample_numpy(shape)

Output:

JAX: 
1.38 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
NumPy
3.17 ms ± 90.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
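
To illustrate the non-tuple shape pitfall mentioned above (this snippet is my own sketch, not from the original notebook): when the shape is marked static under jax.jit, it must be hashable, so a list fails where a tuple works:

```python
import jax

# Jitted sampler with the shape treated as a static argument
sample = jax.jit(lambda key, shape: jax.random.normal(key, shape),
                 static_argnames=['shape'])
key = jax.random.PRNGKey(0)

print(sample(key, (100000,)).shape)  # tuple shape works: (100000,)

try:
  sample(key, [100000])  # lists are unhashable, so jit rejects them
  print("list shape worked")
except Exception:
  print("list shape rejected")
```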

I've also added a gist here for your reference.

Thank you!

jakevdp commented 5 months ago

Thanks @selamw1 – when running micro-benchmarks, please keep in mind the best practices for Benchmarking JAX code. With these in mind, I find these results on a Colab CPU runtime:

import jax
import jax.numpy as jnp
import numpy as np

def sample_jax(key, shape):
  return jax.random.normal(key, shape)

def sample_numpy(shape):
  return np.random.normal(size=shape)

sample_jax_jit = jax.jit(sample_jax, static_argnames=['shape'])

key = jax.random.PRNGKey(0)
shape = (100000,)

# Warm up both versions so compilation isn't counted in the timings
_ = sample_jax(key, shape).block_until_ready()
_ = sample_jax_jit(key, shape).block_until_ready()

print("JAX: ")
%timeit sample_jax(key, shape).block_until_ready()

print("JAX JIT: ")
%timeit sample_jax_jit(key, shape).block_until_ready()

print("NumPy")
%timeit sample_numpy(shape)

Output:

JAX: 
1.72 ms ± 506 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
JAX JIT: 
833 µs ± 33.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
NumPy
3.01 ms ± 57.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

On a Colab GPU runtime I find the following:

JAX: 
547 µs ± 165 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
JAX JIT: 
310 µs ± 68.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
NumPy
2.85 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Also, for a more general discussion of performance of JAX vs. NumPy, see FAQ: Is JAX faster than NumPy?.