Justin-Tan opened this issue 1 year ago
Hi @Justin-Tan
It looks like this issue has been resolved in later versions of JAX. I executed the provided repro code with JAX version 0.4.23 on Google Colab using the GPU runtime, and it executes without any error.
```python
import os, multiprocessing
# Expose one XLA host (CPU) device per core; this must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count())

from jax import config
config.update("jax_enable_x64", True)

import jax
from jax import vmap, pmap
import jax.numpy as jnp
import numpy as np

print(f"Jax version: {jax.__version__}")
devs = jax.devices('cpu')
print(f"Devices: {devs}")
n_devs = len(devs)
print(f"Number of devices: {n_devs}")

# Host-side NumPy function run by the callback.
def _np_sq(x, b):
    return np.square(x - b)

def np_sq(x, b):
    # The callback's result has the same shape and dtype as x.
    shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    res = jax.pure_callback(_np_sq, shape_dtype, x, b,
                            vectorized=False)
    return res
```
Output:
```
Jax version: 0.4.23
Devices: [CpuDevice(id=0), CpuDevice(id=1)]
Number of devices: 2
```
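Newer JAX releases deprecate the `vectorized` argument of `jax.pure_callback` in favour of `vmap_method`, where `vmap_method="sequential"` corresponds to the `vectorized=False` behaviour used above. The following is a minimal, version-portable sketch of the same `np_sq`, under the assumption that choosing the keyword via `inspect.signature` is acceptable; the `_CALLBACK_KW` helper is hypothetical and not part of JAX. It checks a plain `vmap` against NumPy.

```python
import inspect

import jax
import numpy as np

# Match the reproduction above: use 64-bit floats.
jax.config.update("jax_enable_x64", True)


def _np_sq(x, b):
    # Host-side NumPy code run by the callback on concrete arrays.
    return np.square(x - b)


# Hypothetical helper: choose the batching keyword this jax version accepts.
# "sequential" calls the callback per batch element, like vectorized=False.
_params = inspect.signature(jax.pure_callback).parameters
if "vmap_method" in _params:
    _CALLBACK_KW = {"vmap_method": "sequential"}
else:
    _CALLBACK_KW = {"vectorized": False}


def np_sq(x, b):
    # Declare the callback's output: same shape and dtype as x.
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_np_sq, out, x, b, **_CALLBACK_KW)


x = np.random.randn(2, 3)
b = np.random.randn(2, 3)
res = jax.vmap(np_sq)(x, b)  # maps over the leading axis of size 2
print(res)
```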
`vmap`:
```python
x = np.random.randn(2, 3)
b = np.random.randn(2, 3)
print(f"vmap:{vmap(np_sq)(x,b)}")
```
Output:
```
vmap:[[0.11448556 0.5447552  3.27006149]
 [0.16259046 0.12088213 0.45500056]]
```
`pmap`:
```python
y = np.random.randn(n_devs, 3)
c = np.random.randn(n_devs, 3)
print(f"pmap: {pmap(np_sq, devices=devs)(y,c)}")
```
Output:
```
pmap: [[7.33969863 0.19640176 0.60541126]
 [5.45969277 2.96010452 0.0125925 ]]
```
Nesting `vmap` within `pmap`:
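```python
z = np.random.randn(n_devs, 2, 3)
d = np.random.randn(n_devs, 2, 3)
# Note: this call passes y, c (shape (n_devs, 3)), not z, d, so the inner vmap
# maps over the length-3 axis and the output equals the pmap result above.
print(f"Nesting vmap within pmap: {pmap(vmap(np_sq), devices=devs)(y,c)}")
```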
Output:
```
Nesting vmap within pmap: [[7.33969863 0.19640176 0.60541126]
 [5.45969277 2.96010452 0.0125925 ]]
```
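To exercise the `(n_devs, 2, 3)` case that `z` and `d` were created for, here is a self-contained sketch that feeds them through `pmap(vmap(np_sq))` and compares the result with plain NumPy instead of printing it. As an assumption, it reuses a hypothetical `_CALLBACK_KW` helper that selects `vmap_method="sequential"` on newer JAX and `vectorized=False` on older releases.

```python
import multiprocessing
import os

# One XLA host device per CPU core; only effective if set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count())

import inspect

import jax
import numpy as np
from jax import pmap, vmap

jax.config.update("jax_enable_x64", True)


def _np_sq(x, b):
    return np.square(x - b)


# Hypothetical helper (assumption): pick the keyword this jax version accepts.
_params = inspect.signature(jax.pure_callback).parameters
_CALLBACK_KW = ({"vmap_method": "sequential"} if "vmap_method" in _params
                else {"vectorized": False})


def np_sq(x, b):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_np_sq, out, x, b, **_CALLBACK_KW)


devs = jax.devices("cpu")
n_devs = len(devs)

# pmap splits the leading axis (size n_devs) across devices;
# the inner vmap then maps over the axis of size 2 on each device.
z = np.random.randn(n_devs, 2, 3)
d = np.random.randn(n_devs, 2, 3)
nested = pmap(vmap(np_sq), devices=devs)(z, d)
print(nested.shape)
```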
I have also verified the results on a cloud VM with 4 CPUs and 4 GPUs using the latest JAX version, 0.4.24. It produces the following output:
```
Jax version: 0.4.24
Devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Number of devices: 4
vmap:[[0.08676551 0.28720522 0.03171861]
 [0.03151458 1.35523808 0.01098814]]
pmap: [[1.10615892 1.54875127 0.09205769]
 [1.87091675 3.67551264 1.17709292]
 [1.39692925 0.00487045 0.88255363]
 [0.9227732  0.07558887 0.21490667]]
Nesting vmap within pmap: [[1.10615892 1.54875127 0.09205769]
 [1.87091675 3.67551264 1.17709292]
 [1.39692925 0.00487045 0.88255363]
 [0.9227732  0.07558887 0.21490667]]
```
Please find the gist on Colab for reference.
Thank you
Description
Some background: I'm trying to parallelize a CPU-intensive computation across the available CPUs on my machine; the computation calls some `scipy.optimize` routines through `jax.pure_callback`. Using `vmap` and `pmap` separately on `jax.pure_callback` works when `pmap`ing over the available CPUs, but nesting the two does not. Here's a toy example below:

`pmap` works:

but nesting the two does not:
It looks like the reverse composition, $\mathrm{vmap} \circ \mathrm{pmap}$ (i.e. `vmap(pmap(...))`), works, though I'm not sure whether this is advisable. Something similar also happens with `shmap`, although I haven't tried that with the toy example above. Note: with just a single CPU, the code runs fine on my machine.

What jax/jaxlib version are you using?
0.4.14
Which accelerator(s) are you using?
CPUs, on a machine with a GPU. The GPU is not used.
Additional system info
Ubuntu 22.04.2 LTS (GNU/Linux 5.15.0-72-generic x86_64)
NVIDIA GPU info
No response
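The report says the reverse composition, `vmap(pmap(np_sq))`, works. A minimal sketch of that composition, assuming inputs of shape `(2, n_devs, 3)`, is below; the names `w`, `e`, `reverse`, and the `_CALLBACK_KW` helper are illustrative, and the thread does not settle whether this ordering is advisable compared with `pmap(vmap(...))`.

```python
import multiprocessing
import os

# One XLA host device per CPU core; only effective if set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count())

import inspect

import jax
import numpy as np
from jax import pmap, vmap

jax.config.update("jax_enable_x64", True)


def _np_sq(x, b):
    return np.square(x - b)


# Hypothetical helper (assumption): pick the keyword this jax version accepts.
_params = inspect.signature(jax.pure_callback).parameters
_CALLBACK_KW = ({"vmap_method": "sequential"} if "vmap_method" in _params
                else {"vectorized": False})


def np_sq(x, b):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_np_sq, out, x, b, **_CALLBACK_KW)


devs = jax.devices("cpu")
n_devs = len(devs)

# vmap is outermost (axis of size 2); pmap splits the next axis
# (size n_devs) across devices, so inputs are (2, n_devs, 3).
w = np.random.randn(2, n_devs, 3)
e = np.random.randn(2, n_devs, 3)
reverse = vmap(pmap(np_sq, devices=devs))(w, e)
print(reverse.shape)
```

With `vmap` outermost, the axis split across devices is the second one rather than the first, so the input layout differs from the `pmap(vmap(...))` case.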