jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unable to use JAX pmap with CPU cores #19543

Closed Deepakgthomas closed 8 months ago

Deepakgthomas commented 8 months ago

Description

I am trying to use JAX pmap, but I get an error saying that not enough XLA devices are visible. Here is my code:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax

from jax import pmap
import jax.numpy as jnp

out = pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)

Traceback (most recent call last):
  File "new.py", line 10, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 682, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 923, in compile
    executable = UnloadedPmapExecutable.from_hlo(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 993, in from_hlo
    raise ValueError(msg.format(shards.num_global_shards,
jax._src.traceback_util.UnfilteredStackTrace: ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "new.py", line 10, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)

What jax/jaxlib version are you using?

jax 0.4.13, jaxlib 0.4.13

Which accelerator(s) are you using?

CPU/GPU

Additional system info?

uname_result(system='Linux', node='thoma-Lenovo-Legion-5-15IMH05H', release='6.5.0-15-generic', version='#15~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Jan 12 18:54:30 UTC 2', machine='x86_64', processor='x86_64')

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GTX 1660 Ti     Off | 00000000:01:00.0  On |                  N/A |
| N/A   48C    P8               4W /  80W |    538MiB /  6144MiB |     13%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      2429      G   /usr/lib/xorg/Xorg                          269MiB |
|    0   N/A  N/A      2635      G   /usr/bin/gnome-shell                         46MiB |
|    0   N/A  N/A      4276      G   ...seed-version=20240126-130123.803000       97MiB |
|    0   N/A  N/A      6368      G   ...,WinRetrieveSuggestionsOnlyOnDemand      122MiB |
+---------------------------------------------------------------------------------------+
jakevdp commented 8 months ago

It seems like this is a duplicate of #19541 – no need to ask this question multiple times. Thanks!

Deepakgthomas commented 8 months ago

My apologies. I thought that there was a bug. Thanks a lot for the help.
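Editor's note for readers landing here: per the linked discussion, the likely cause is that with a GPU-enabled jaxlib installed, JAX picks the GPU backend by default, and `--xla_force_host_platform_device_count` only affects the CPU backend, so pmap sees just the single GPU. A minimal sketch of a workaround, assuming jax 0.4.x: force the CPU platform before importing jax.

```python
import os

# Both variables must be set before `import jax`.
# Force the CPU backend even when a GPU-enabled jaxlib is installed.
os.environ["JAX_PLATFORMS"] = "cpu"
# Split the host CPU into 8 logical XLA devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

print(jax.device_count())  # 8 emulated CPU devices
out = jax.pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)  # [ 0  1  4  9 16 25 36 49]
```

Note that these emulated devices share one physical CPU, so this is useful for testing pmap logic rather than for speedups.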
