google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot Share GPUs between processes when using JAX with SLURM #10021

Open chrisgrimm opened 2 years ago

chrisgrimm commented 2 years ago

It's possible that this is a limitation of SLURM rather than JAX, but I think it might be worth starting a discussion about this here to get to the bottom of it.

I'm trying to launch scripts that use JAX under SLURM. Specifically, I want to launch multiple JAX processes and have them share a single allocated GPU. The issue I'm running into is that one process can detect the GPU while the others cannot. The remaining processes fall back to the CPU after printing the following warning:

2022-03-23 20:23:46.386347: W external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:205] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal

It's worth noting that running nvidia-smi from within each process succeeds and shows statistics for the GPU I am attempting to share.
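Since nvidia-smi works inside each process, it can also report how much GPU memory each process holds, which helps tell "the GPU is invisible" apart from "another process took the memory". The sketch below uses nvidia-smi's real `--query-compute-apps` and `--format=csv,noheader,nounits` options; the helper names are hypothetical, and it assumes nvidia-smi prints numeric values (inside some containers or PID namespaces it may list no processes or print `[N/A]`, which this parser does not handle).

```python
import shutil
import subprocess
from typing import List, Tuple

# nvidia-smi query: one "pid, used_memory" line per compute process, memory in MiB.
QUERY = [
    "nvidia-smi",
    "--query-compute-apps=pid,used_memory",
    "--format=csv,noheader,nounits",
]


def parse_compute_apps(csv_text: str) -> List[Tuple[int, int]]:
    """Parse 'pid, used_memory' lines into (pid, MiB) tuples."""
    rows = []
    for line in csv_text.strip().splitlines():
        pid, mem = (field.strip() for field in line.split(","))
        rows.append((int(pid), int(mem)))
    return rows


def gpu_memory_by_process() -> List[Tuple[int, int]]:
    """Run nvidia-smi and return per-process GPU memory (empty if nvidia-smi is absent)."""
    if shutil.which("nvidia-smi") is None:
        return []
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
    return parse_compute_apps(out)


if __name__ == "__main__":
    for pid, mib in gpu_memory_by_process():
        print(f"pid {pid}: {mib} MiB")
```

If one JAX process shows most of the card's memory while the others fail to start on the GPU, that points at memory preallocation rather than device visibility.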

Does anyone have an idea why this CUDA_ERROR_INVALID_DEVICE error might be arising in my context?

skye commented 2 years ago

Can you try setting the env var XLA_PYTHON_CLIENT_PREALLOCATE=false? See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more information and similar options. I think the issue might be that the first jax process is pre-allocating most of the GPU memory, causing other processes to OOM (sometimes OOMs show up in confusing ways).
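A minimal sketch of applying this suggestion from inside the script itself. The two environment variables are the ones named in this thread; the helper name is hypothetical. The variables must be set before JAX initializes its GPU backend, so the safest place is before `import jax`. How `XLA_PYTHON_CLIENT_MEM_FRACTION` behaves when preallocation is off is worth checking against the JAX GPU memory allocation page for your version.

```python
import os
from typing import Optional


def configure_xla_gpu_memory(mem_fraction: Optional[float] = None) -> None:
    """Set JAX's GPU memory environment variables for this process.

    Must run before JAX initializes its GPU backend, i.e. before
    `import jax` (or at the latest before the first JAX operation).
    """
    # Allocate GPU memory on demand instead of reserving a large block up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError(f"mem_fraction must be in (0, 1], got {mem_fraction}")
        # Fraction of total GPU memory this process's XLA allocator may use.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_xla_gpu_memory(mem_fraction=0.2)
# import jax  # import only after the environment has been configured
```

Setting the variables in the job script (e.g. `export XLA_PYTHON_CLIENT_PREALLOCATE=false`) has the same effect and avoids any ordering concerns.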

hawkinsp commented 2 years ago

I'm not sure what's going on here, but I'm wondering what exactly SLURM does to expose a particular GPU to a task. Does it set CUDA_VISIBLE_DEVICES? Something else? Perhaps it sets some other environment variables?

Since I have no access to or experience with SLURM, I'd like to see whether we can reproduce the problem without it. Can we replicate, in isolation, what SLURM does to the task? That would let us reproduce the problem and debug it.
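One way to start is to capture the GPU-related variables a SLURM task actually receives and replay them in a plain shell. The sketch below is hypothetical tooling: the variable list is taken from the SLURM task environment dumped later in this thread, and other clusters may set more. Note that replaying environment variables cannot reproduce any cgroup-based device restriction SLURM applies.

```python
import os
import shlex
from typing import Dict, Mapping, Optional

# GPU-related variables present in the SLURM task environment dumped in this
# thread; other clusters may set additional ones.
GPU_ENV_KEYS = (
    "CUDA_VISIBLE_DEVICES",
    "ROCR_VISIBLE_DEVICES",
    "GPU_DEVICE_ORDINAL",
    "SLURM_STEP_GPUS",
    "SLURM_GPUS_ON_NODE",
)


def gpu_env(environ: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return only the GPU-related variables present in `environ`."""
    env = os.environ if environ is None else environ
    return {key: env[key] for key in GPU_ENV_KEYS if key in env}


def as_exports(variables: Mapping[str, str]) -> str:
    """Render variables as shell `export` lines for replaying outside SLURM."""
    return "\n".join(
        f"export {key}={shlex.quote(value)}"
        for key, value in sorted(variables.items())
    )


if __name__ == "__main__":
    # Run inside the SLURM task, then paste the output into a non-SLURM shell.
    print(as_exports(gpu_env()))
```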

sudhakarsingh27 commented 1 year ago

@chrisgrimm were you able to resolve this?

sweetice commented 1 year ago

Similar issue here; waiting for a solution. I cannot run multiple processes on one GPU.

maxosmith commented 1 year ago

Here's a minimal reproducible example, in which the second GPU node cannot bind to the GPU:

import os
import time

import jax
import launchpad as lp
import tensorflow as tf
from absl import app
from launchpad.nodes.python.local_multi_processing import PythonProcess

class DeviceTester:
    def test(self):
        print(os.environ)
        print("Backends: ", jax.default_backend())
        print("Devices: ", jax.devices())

    def run(self):
        time.sleep(5)
        lp.stop()

def _build_test_node():
    return lp.CourierNode(DeviceTester)

def main(_):
    # Test independent of Launchpad.
    print("\n\nLocal")
    print(os.environ)
    test = DeviceTester()
    test.test()

    # Test GPU accessibility on Launchpad.
    print("\n\nLaunchpad")

    program = lp.Program(name="experiment")
    handles = []

    with program.group("test_cpu"):
        handles.append(program.add_node(_build_test_node()))

    with program.group("test_gpu"):
        handles.append(program.add_node(_build_test_node()))
        handles.append(program.add_node(_build_test_node()))

    lp.launch(
        program,
        launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources={
            "test_cpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "",
                    "JAX_PLATFORM_NAME": "cpu",
                }
            ),
            "test_gpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "0",
                    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".2",
                    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
                    "JAX_PLATFORM_NAME": "gpu",
                }
            ),
        },
    )

    for handle in handles:
        handle.dereference().test()

if __name__ == "__main__":
    # Provide access to --jax_backend_target and --jax_xla_backend flags.
    jax.config.config_with_absl()
    # Binary should use CPU
    jax.config.update("jax_platform_name", "cpu")
    tf.config.experimental.set_visible_devices([], "GPU")
    app.run(main)
Running it fails with the following error (the earlier part of the traceback is not shown):

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "launchpad_gpu_test.py", line 80, in <module>
    app.run(main)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "launchpad_gpu_test.py", line 70, in main
    handle.dereference().test()
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/courier/python/client.py", line 52, in inner_function
    raise translate_status(e.status) from e
pybind11_abseil.status.StatusNotOk: Python exception was raised on the server:
Traceback (most recent call last):
  File "launchpad_gpu_test.py", line 14, in test
    print("Backends: ", jax.default_backend())
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 490, in default_backend
    return get_backend(None).platform
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 427, in get_backend
    return _get_backend_uncached(platform)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 413, in _get_backend_uncached
    platform = canonicalize_platform(platform)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 294, in canonicalize_platform
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu

Both GPU nodes have the same env variables, which are:

[test_gpu/1] environ({'CONDA_SHLVL': '2', 'LD_LIBRARY_PATH': '/home/mxsmith/.conda/envs/model38/lib', 'LS_COLORS': 'no=00:di=34;01:tw=34;01:ow=34;01:fi=00:ln=00:pi=00:so=00:bd=00:cd=00:or=00:mi=00:ex=00:*.sh=31:*.exe=31:*.bat=31', 'CONDA_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/conda', 'SRUN_DEBUG': '3', 'SLURM_STEP_ID': '0', 'SLURM_STEP_GPUS': '0', 'SLURM_NODEID': '0', 'SLURM_TASK_PID': '827752', 'HTTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SSH_CONNECTION': '141.211.21.82 51992 141.211.192.38 22', 'SLURM_PRIO_PROCESS': '0', 'SLURM_CPU_BIND_VERBOSE': 'quiet', 'IEX_TOKEN': 'pk_6cd11420a64b4d5a856ac31281418f38', 'LANG': 'en_US.UTF-8', 'SLURM_SUBMIT_DIR': '/home/mxsmith', 'HISTCONTROL': 'ignoreboth:erasedups', 'HOSTNAME': 'gl1520.arc-ts.umich.edu', 'OLDPWD': '/home/mxsmith', 'SLURM_STEPID': '0', 'SLURM_SRUN_COMM_HOST': '141.211.192.38', 'EDITOR': 'emacs', 'SLURM_DISTRIBUTION': 'cyclic', 'ROCR_VISIBLE_DEVICES': '0', 'CONDA_PREFIX': '/home/mxsmith/.conda/envs/model38', 'SQUEUE_FORMAT': '%.18i %.9P %40j %.8u %.2t %.10M %.20R', 'SLURM_PROCID': '0', 'SLURM_JOB_GID': '99464869', 'SLURM_CPU_BIND': 'quiet,mask_cpu:0x00000001', 'SLURMD_NODENAME': 'gl1520', 'GIT_EDITOR': 'emacs', 'SLURM_TASKS_PER_NODE': '1', 'S_COLORS': 'auto', '_CE_M': '', 'XLA_PYTHON_CLIENT_PREALLOCATE': 'false', 'TF2_BEHAVIOR': '1', 'XDG_SESSION_ID': '11093', 'SLURM_NNODES': '1', 'USER': 'mxsmith', 'SLURM_LAUNCH_NODE_IPADDR': '141.211.192.38', 'CONDA_PREFIX_1': '/sw/pkgs/arc/python3.9-anaconda/2021.11', 'SLURM_STEP_TASKS_PER_NODE': '1', 'MATPLOTLIBRC': '/home/mxsmith/profile/matplotlib', 'FTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'PWD': '/home/mxsmith/projects', 'SSH_ASKPASS': '/usr/libexec/openssh/gnome-ssh-askpass', 'SLURM_JOB_NODELIST': 'gl1520', 'HOME': '/home/mxsmith', 'SLURM_CLUSTER_NAME': 'greatlakes', 'CONDA_PYTHON_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/python', 'SLURM_NODELIST': 'gl1520', 'SLURM_GPUS_ON_NODE': '1', 'SSH_CLIENT': '141.211.21.82 51992 22', 
'LMOD_VERSION': '8.6.14', 'SLURM_NTASKS': '1', 'TMUX': '/tmp/tmux-99464869/default,827956,0', 'rsync_proxy': 'proxy.arc-ts.umich.edu:3128', 'SLURM_UMASK': '0002', 'https_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'KRB5CCNAME': 'FILE:/tmp/krb5cc_99464869_TngfOC', 'TF_CPP_MIN_LOG_LEVEL': '1', 'SLURM_JOB_CPUS_PER_NODE': '4', 'BASH_ENV': '/sw/lmod/lmod/init/bash', 'XDG_DATA_DIRS': '/home/mxsmith/.local/share/flatpak/exports/share:/var/lib/flatpak/exports/share:/usr/local/share:/usr/share', 'AUTOJUMP_ERROR_PATH': '/home/mxsmith/.local/share/autojump/errors.log', 'SLURM_TOPOLOGY_ADDR': 'gl1520', 'http_proxy': 'http://proxy.arc-ts.umich.edu:3128/', '_CE_CONDA': '', 'SLURM_WORKING_CLUSTER': 'greatlakes:glctld:6817:9472:109', 'SLURM_STEP_NODELIST': 'gl1520', 'SLURM_JOB_NAME': 'bash', 'SLURM_SRUN_COMM_PORT': '60207', 'TMPDIR': '/tmp', 'LMOD_sys': 'Linux', 'SLURM_JOBID': '45908601', 'JAX_PLATFORM_NAME': 'gpu', 'SLURM_CONF': '/var/spool/slurmd.spool/conf-cache/slurm.conf', 'LMOD_AVAIL_STYLE': 'grouped', 'no_proxy': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'LMOD_ROOT': '/sw/lmod', 'SLURM_JOB_QOS': 'normal', 'SLURM_TOPOLOGY_ADDR_PATTERN': 'node', 'CONDA_PROMPT_MODIFIER': '(model38) ', 'SSH_TTY': '/dev/pts/80', 'NO_PROXY': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'MAIL': '/var/spool/mail/mxsmith', 'HTTPS_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SLURM_CPUS_ON_NODE': '4', 'XLA_PYTHON_CLIENT_MEM_FRACTION': '.2', 'VISUAL': 'emacs', 'SLURM_JOB_NUM_NODES': '1', 'AUTOJUMP_SOURCED': '1', 'SHELL': '/bin/bash', 'TERM': 'xterm-256color', 'SLURM_JOB_UID': '99464869', 'SLURM_JOB_PARTITION': 'spgpu', 'SLURM_PTY_WIN_ROW': '49', 'SLURM_CPU_BIND_LIST': '0x00000001', 'SLURM_JOB_USER': 'mxsmith', 'CUDA_VISIBLE_DEVICES': '0', 'SLURM_PTY_WIN_COL': '97', 'TMUX_PANE': '%3', 'SLURM_NPROCS': '1', 'SHLVL': '3', 'SLURM_SUBMIT_HOST': 'gl-login1.arc-ts.umich.edu', 'SLURM_JOB_ACCOUNT': 'wellman0', 'MANPATH': 
'/sw/lmod/lmod/share/man:/usr/local/share/man:/usr/share/man:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man:/opt/slurm/share/man/:/opt/TurboVNC/man/:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man', 'SLURM_STEP_LAUNCHER_PORT': '60207', 'MODULEPATH': '/sw/lmod/lmod/modulefiles/Core:/sw/modules/Core:/sw/modules/Collections', 'SLURM_PTY_PORT': '60206', 'SLURM_GTIDS': '0', 'LOGNAME': 'mxsmith', 'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/99464869/bus', 'XDG_RUNTIME_DIR': '/run/user/99464869', 'MODULEPATH_ROOT': '/sw/modules', 'LMOD_PACKAGE_PATH': '/sw/lmod', 'PATH': '/home/mxsmith/software/bin:/home/mxsmith/software/bin:/home/mxsmith/.conda/envs/model38/bin:/sw/pkgs/arc/python3.9-anaconda/2021.11/condabin:/home/mxsmith/software/bin:/opt/TurboVNC/bin:/opt/slurm/bin:/opt/slurm/sbin:/sw/pkgs/arc/usertools/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/usr/lpp/mmfs/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin', 'SLURM_JOB_ID': '45908601', 'SLURM_CPU_BIND_TYPE': 'mask_cpu:', 'SLURM_STEP_NUM_TASKS': '1', 'MODULESHOME': '/sw/lmod/lmod', 'CONDA_DEFAULT_ENV': 'model38', 'LMOD_SETTARG_FULL_SUPPORT': 'no', 'HISTSIZE': '1000', 'LMOD_PKG': '/sw/lmod/lmod', 'CLUSTER_NAME': 'greatlakes', 'SLURM_STEP_NUM_NODES': '1', 'ftp_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'RSYNC_PROXY': 'proxy.arc-ts.umich.edu:3128', 'LMOD_CMD': '/sw/lmod/lmod/libexec/lmod', 'SLURM_LOCALID': '0', 'GPU_DEVICE_ORDINAL': '0', 'LESSOPEN': '||/usr/bin/lesspipe.sh %s', 'LMOD_DIR': '/sw/lmod/lmod/libexec', 'BASH_FUNC_module%%': '() {  local __lmod_my_status;\n local __lmod_sh_dbg;\n if [ -z "${LMOD_SH_DBG_ON+x}" ]; then\n case "$-" in \n *v*x*)\n __lmod_sh_dbg=\'vx\'\n ;;\n *v*)\n __lmod_sh_dbg=\'v\'\n ;;\n *x*)\n 
__lmod_sh_dbg=\'x\'\n ;;\n esac;\n fi;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n set +$__lmod_sh_dbg;\n echo "Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for Lmod\'s output" 1>&2;\n fi;\n eval "$($LMOD_CMD bash "$@")" && eval $(${LMOD_SETTARG_CMD:-:} -s sh);\n __lmod_my_status=$?;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n echo "Shell debugging restarted" 1>&2;\n set -$__lmod_sh_dbg;\n fi;\n return $__lmod_my_status\n}', 'BASH_FUNC_ml%%': '() {  eval "$($LMOD_DIR/ml_cmd "$@")"\n}'})
Package versions:

tensorflow==2.8.3
jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1671027717961/work
jaxlib==0.4.1+cuda11.cudnn86
dm-launchpad==0.5.2p
maxosmith commented 1 year ago

If you add the following to the nodes, you can see that the cuda backend is missing in test_gpu/1:

from jax._src.lib import xla_bridge
print("Backends: ", xla_bridge.backends())
[test_gpu/0] Backends:  {'interpreter': <jaxlib.xla_extension.Client object at 0x14b25b27ac30>, 'cpu': <jaxlib.xla_extension.Client object at 0x14b24d95b1b0>, 'cuda': <jaxlib.xla_extension.Client object at 0x14b24c2f7570>}
[test_gpu/1] Backends:  {'interpreter': <jaxlib.xla_extension.Client object at 0x14b3e2dc0f30>, 'cpu': <jaxlib.xla_extension.Client object at 0x14b3e2dc09f0>}

If you also add:

from jax.config import config
print(config.jax_platforms)

It is None in both nodes, so JAX runs every registered backend factory. The registered factories are the same in both nodes:

[test_gpu/1] {'interpreter': (<function make_interpreter_client at 0x145e83579b80>, -100), 'cpu': (functools.partial(<function make_cpu_client at 0x145e8359c790>, use_tfrt=True), 0), 'tpu_driver': (<function _make_tpu_driver_client at 0x145e728bfdc0>, 100), 'cuda': (functools.partial(<function make_gpu_client at 0x145e728bff70>, platform_name='cuda', visible_devices_flag='jax_cuda_visible_devices'), 200), 'rocm': (functools.partial(<function make_gpu_client at 0x145e728bff70>, platform_name='rocm', visible_devices_flag='jax_rocm_visible_devices'), 200), 'tpu': (functools.partial(<function tpu_client_timer_callback at 0x145e728bfe50>, timer_secs=60.0), 300), 'plugin': (<function make_plugin_device_client at 0x145e8359caf0>, 400)}
[test_gpu/0] {'interpreter': (<function make_interpreter_client at 0x150b925e0b80>, -100), 'cpu': (functools.partial(<function make_cpu_client at 0x150b92602790>, use_tfrt=True), 0), 'tpu_driver': (<function _make_tpu_driver_client at 0x150b81925dc0>, 100), 'cuda': (functools.partial(<function make_gpu_client at 0x150b81925f70>, platform_name='cuda', visible_devices_flag='jax_cuda_visible_devices'), 200), 'rocm': (functools.partial(<function make_gpu_client at 0x150b81925f70>, platform_name='rocm', visible_devices_flag='jax_rocm_visible_devices'), 200), 'tpu': (functools.partial(<function tpu_client_timer_callback at 0x150b81925e50>, timer_secs=60.0), 300), 'plugin': (<function make_plugin_device_client at 0x150b92602af0>, 400)}
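Because jax_platforms is None, a failing cuda factory is skipped with a warning and JAX continues with the remaining backends. Restricting JAX to a single platform via the `JAX_PLATFORMS` environment variable makes JAX raise the underlying initialization error instead, which is more useful for debugging. A minimal sketch, assuming the platform name `cuda` (as listed in the factories above) and hypothetical helper names:

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional

# Child-process probe: prints the backend and devices JAX actually initialized.
PROBE = "import jax; print(jax.default_backend(), jax.devices())"


def strict_gpu_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy `base` and restrict JAX to the CUDA backend only.

    With JAX_PLATFORMS set, JAX initializes only the listed platforms and
    raises if one of them fails, instead of silently using the CPU.
    """
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cuda"
    # Drop the older single-platform setting so that only JAX_PLATFORMS applies.
    env.pop("JAX_PLATFORM_NAME", None)
    return env


def probe_gpu() -> subprocess.CompletedProcess:
    """Run the probe in a fresh interpreter with the strict environment."""
    return subprocess.run(
        [sys.executable, "-c", PROBE],
        env=strict_gpu_env(),
        capture_output=True,
        text=True,
    )
```

In the failing node, `probe_gpu().stderr` should then show why the cuda client could not be created (for example a CUDA_ERROR_INVALID_DEVICE message) rather than an "Unknown backend" error.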
maxosmith commented 1 year ago

If you further print out:

from jax._src.config import flags

FLAGS = flags.FLAGS
print(FLAGS.jax_cuda_visible_devices)

from jax._src import distributed

print(distributed.global_state.client)
print(distributed.global_state.service)
print(distributed.global_state.process_id)

Both have the same settings:

[test_gpu/1] all
[test_gpu/1] None
[test_gpu/1] None
[test_gpu/1] 0
[test_gpu/0] all
[test_gpu/0] None
[test_gpu/0] None
[test_gpu/0] 0

As far as I can tell, all of the system's settings are the same during the hand-off to XLA.

nouiz commented 1 year ago

The original issue was about sharing one GPU between multiple jobs on SLURM. Is that your case? It wasn't clear. If it isn't, can you open a new issue?

@hawkinsp SLURM can be configured in several ways to control GPU access. Older configurations did it via CUDA_VISIBLE_DEVICES; more recent ones use cgroups. cgroups are better because they enforce the restriction, whereas an environment variable can simply be overridden by the end user.
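To see which mechanism a given cluster uses, one can try to open each NVIDIA device node from inside a job. The sketch below assumes the usual `/dev/nvidiaN` naming; the helper name is hypothetical. Under cgroup-based confinement, nodes for GPUs outside the allocation typically still exist but fail to open with EPERM, while with CUDA_VISIBLE_DEVICES alone every node usually opens.

```python
import errno
import glob
import os
from typing import Dict


def probe_device_nodes(pattern: str = "/dev/nvidia[0-9]*") -> Dict[str, str]:
    """Try to open each matching device node read-only and report the result.

    Returns a mapping from path to "ok" or an errno name such as "EPERM".
    """
    results = {}
    for path in sorted(glob.glob(pattern)):
        try:
            fd = os.open(path, os.O_RDONLY)
        except OSError as exc:
            results[path] = errno.errorcode.get(exc.errno, str(exc.errno))
        else:
            os.close(fd)
            results[path] = "ok"
    return results


if __name__ == "__main__":
    print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
    for path, status in probe_device_nodes().items():
        print(path, status)
```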

maxosmith commented 1 year ago

I guess I read "Particularly, I want to launch multiple processes that use JAX and have them share a single allocated GPU. " incorrectly then. I'll open a different issue.

nouiz commented 1 year ago

As for the original issue: SLURM prevents this by default, and that is normal SLURM behavior. Otherwise, other users could allocate memory on your GPUs and use their compute, which could crash or kill your own jobs.

I'm not a SLURM expert, so there may be a cleaner way to do this. But you could get what you want by starting one SLURM job per GPU, where that job is a bash script that dispatches multiple processes onto the GPU it has access to.
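The same idea can be sketched in Python instead of bash: one SLURM job requests a single GPU and runs a launcher that starts several JAX worker processes, each with preallocation disabled and a share of the GPU memory. Everything here is an assumption for illustration: the helper names, the `WORKER_RANK` variable, and the 0.9 headroom factor are hypothetical choices, not part of JAX or SLURM.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping, Optional


def worker_envs(
    n_workers: int,
    headroom: float = 0.9,
    base: Optional[Mapping[str, str]] = None,
) -> List[Dict[str, str]]:
    """Build one environment per worker so all workers can share one GPU.

    CUDA_VISIBLE_DEVICES is inherited unchanged from the SLURM job, so every
    worker sees the single GPU the job was allocated.
    """
    if n_workers < 1:
        raise ValueError("n_workers must be >= 1")
    # Split headroom * total GPU memory evenly (0.9 is an arbitrary choice).
    fraction = headroom / n_workers
    envs = []
    for rank in range(n_workers):
        env = dict(os.environ if base is None else base)
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.3f}"
        env["WORKER_RANK"] = str(rank)  # hypothetical variable for the worker script
        envs.append(env)
    return envs


def launch_workers(script: str, n_workers: int) -> List[int]:
    """Start `n_workers` copies of `script` in parallel; return their exit codes."""
    procs = [
        subprocess.Popen([sys.executable, script], env=env)
        for env in worker_envs(n_workers)
    ]
    return [proc.wait() for proc in procs]
```

Submitted as a single job that requests one GPU, a launcher calling `launch_workers("worker.py", 4)` (where `worker.py` stands for your own JAX script) would run all four processes inside the same allocation, so they share the job's device access instead of competing for it across jobs.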

@sweetice what do you think of that solution?

nestordemeure commented 1 year ago

I believe I hit the same issue here (a search had not surfaced this issue until now).

The solution that worked for me was to:

Note that there is also an issue with JAX asking for a non-existent device number (I point to the problematic line in my issue).

nouiz commented 1 year ago

@nestordemeure Can you create a bug specific for your last sentence so that it doesn't get lost? Thanks for the report!