jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Nesting `vmap` within `pmap` using `jax.pure_callback` segfaults #17001

Open Justin-Tan opened 1 year ago

Justin-Tan commented 1 year ago

Description

Some background: I'm trying to parallelize a CPU-intensive computation that calls out to some scipy.optimize routines via jax.pure_callback, spreading the work across the available CPUs on my machine. Applying vmap and pmap separately to a function wrapping jax.pure_callback works (when pmapping over the available CPU devices), but nesting the two segfaults. Here's a toy example:

# Expose each CPU core as a separate XLA host device so pmap can shard over them.
import os, multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count())

from jax import config
config.update("jax_enable_x64", True)
import jax
from jax import vmap, pmap
import jax.numpy as jnp
import numpy as np

devs = jax.devices('cpu')
n_devs = len(devs)

# Plain NumPy function to be invoked through the callback.
def _np_sq(x, b): return np.square(x - b)

def np_sq(x, b):
    # The callback result has the same shape and dtype as x.
    shape_dtype = jax.ShapeDtypeStruct(shape=x.shape,
                                       dtype=x.dtype)
    res = jax.pure_callback(_np_sq, shape_dtype, x, b,
                            vectorized=False)
    return res

x = np.random.randn(2, 3)
b = np.random.randn(2, 3)
vmap(np_sq)(x, b)
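
The vmap-only call above runs fine; the call that actually triggers the reported segfault is presumably the nested form, along the lines of the sketch below (the leading device axis on the inputs is assumed, following the setup above):

# Assumed crashing call (not spelled out in the repro above): pmap over the
# CPU devices with vmap inside, on inputs carrying a leading device axis.
z = np.random.randn(n_devs, 2, 3)
d = np.random.randn(n_devs, 2, 3)
pmap(vmap(np_sq), devices=devs)(z, d)  # reportedly segfaults on jax/jaxlib 0.4.14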

What jax/jaxlib version are you using?

0.4.14

Which accelerator(s) are you using?

CPUs, on a machine with a GPU. The GPU is not used.

Additional system info

Ubuntu 22.04.2 LTS (GNU/Linux 5.15.0-72-generic x86_64)

NVIDIA GPU info

No response

rajasekharporeddy commented 8 months ago

Hi @Justin-Tan

It looks like this issue has been resolved in later versions of JAX. I ran the provided repro with JAX version 0.4.23 on Google Colab using a GPU runtime, and it executes without any error.

import os, multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count())

from jax import config
config.update("jax_enable_x64", True)
import jax
from jax import vmap, pmap
import jax.numpy as jnp
import numpy as np

print(f"Jax version: {jax.__version__}")

devs = jax.devices('cpu')
print(f"Devices: {devs}")
n_devs = len(devs)
print(f"Number of devices: {n_devs}")

def _np_sq(x,b): return np.square(x-b)

def np_sq(x,b):
    shape_dtype = jax.ShapeDtypeStruct(shape=x.shape,
                                       dtype=x.dtype)
    res = jax.pure_callback(_np_sq, shape_dtype, x, b,
                            vectorized=False)
    return res

Output:

Jax version: 0.4.23
Devices: [CpuDevice(id=0), CpuDevice(id=1)]
Number of devices: 2

vmap:

x = np.random.randn(2, 3)
b = np.random.randn(2, 3)
print(f"vmap:{vmap(np_sq)(x,b)}")

Output:

vmap:[[0.11448556 0.5447552  3.27006149]
 [0.16259046 0.12088213 0.45500056]]

pmap:

y = np.random.randn(n_devs, 3)
c = np.random.randn(n_devs, 3)
print(f"pmap: {pmap(np_sq, devices=devs)(y,c)}")

Output:

pmap: [[7.33969863 0.19640176 0.60541126]
 [5.45969277 2.96010452 0.0125925 ]]

Nesting vmap within pmap:

z = np.random.randn(n_devs, 2, 3)
d = np.random.randn(n_devs, 2, 3)
print(f"Nesting vmap within pmap: {pmap(vmap(np_sq), devices=devs)(y,c)}")

Output:

Nesting vmap within pmap: [[7.33969863 0.19640176 0.60541126]
 [5.45969277 2.96010452 0.0125925 ]]

I have also verified the results on a cloud VM with 4 CPUs and 4 GPUs using the latest JAX version, 0.4.24. It produces the following output:

Jax version: 0.4.24
Devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Number of devices: 4
vmap:[[0.08676551 0.28720522 0.03171861]
 [0.03151458 1.35523808 0.01098814]]
pmap: [[1.10615892 1.54875127 0.09205769]
 [1.87091675 3.67551264 1.17709292]
 [1.39692925 0.00487045 0.88255363]
 [0.9227732  0.07558887 0.21490667]]
Nesting vmap within pmap: [[1.10615892 1.54875127 0.09205769]
 [1.87091675 3.67551264 1.17709292]
 [1.39692925 0.00487045 0.88255363]
 [0.9227732  0.07558887 0.21490667]]

Please find the gist on Colab for reference.

Thank you
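
A side note for anyone running this repro on newer JAX releases: the vectorized argument to jax.pure_callback has since been deprecated in favor of vmap_method (roughly JAX 0.4.34 onwards). Below is a minimal sketch of the wrapper updated accordingly, keeping the same toy function; vmap_method="sequential" is just one of the available options, not something the posters above used.

def np_sq(x, b):
    shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    # vmap_method tells JAX how to treat the callback under vmap;
    # "sequential" simply loops the callback over the mapped axis.
    return jax.pure_callback(_np_sq, shape_dtype, x, b,
                             vmap_method="sequential")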