google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Multiple subprocesses cannot see GPU #13687

Closed maxosmith closed 1 year ago

maxosmith commented 1 year ago

Description

I am having difficulty getting JAX to share a GPU backend across subprocesses in a SLURM job. Whichever process binds to the GPU first works correctly, and all other processes cannot find a GPU backend.

Below is a minimal example. In it I've experimented with the main process and varying numbers of subprocesses attempting to bind to the single GPU. I've tried various permutations of XLA flags for memory preallocation and visible devices.

import os
import time

import jax
import launchpad as lp
import tensorflow as tf
from absl import app
from launchpad.nodes.python.local_multi_processing import PythonProcess

class DeviceTester:
    def test(self):
        print(os.environ)
        print("Backends: ", jax.default_backend())
        print("Devices: ", jax.devices())

    def run(self):
        time.sleep(5)
        lp.stop()

def _build_test_node():
    return lp.CourierNode(DeviceTester)

def main(_):
    # Test independent of Launchpad.
    print("\n\nLocal")
    print(os.environ)
    test = DeviceTester()
    test.test()

    # Test GPU accessibility on Launchpad.
    print("\n\nLaunchpad")

    program = lp.Program(name="experiment")
    handles = []

    with program.group("test_cpu"):
        handles.append(program.add_node(_build_test_node()))

    with program.group("test_gpu"):
        handles.append(program.add_node(_build_test_node()))
        handles.append(program.add_node(_build_test_node()))

    lp.launch(
        program,
        launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources={
            "test_cpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "",
                    "JAX_PLATFORM_NAME": "cpu",
                }
            ),
            "test_gpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "0",
                    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".2",
                    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
                    "JAX_PLATFORM_NAME": "gpu",
                }
            ),
        },
    )

    for handle in handles:
        handle.dereference().test()

if __name__ == "__main__":
    # Provide access to --jax_backend_target and --jax_xla_backend flags.
    jax.config.config_with_absl()
    # Binary should use CPU
    jax.config.update("jax_platform_name", "cpu")
    tf.config.experimental.set_visible_devices([], "GPU")
    app.run(main)
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "launchpad_gpu_test.py", line 80, in <module>
    app.run(main)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "launchpad_gpu_test.py", line 70, in main
    handle.dereference().test()
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/courier/python/client.py", line 52, in inner_function
    raise translate_status(e.status) from e
pybind11_abseil.status.StatusNotOk: Python exception was raised on the server:
Traceback (most recent call last):
  File "launchpad_gpu_test.py", line 14, in test
    print("Backends: ", jax.default_backend())
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 490, in default_backend
    return get_backend(None).platform
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 427, in get_backend
    return _get_backend_uncached(platform)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 413, in _get_backend_uncached
    platform = canonicalize_platform(platform)
  File "/home/mxsmith/.conda/envs/model38/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 294, in canonicalize_platform
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu

Both GPU nodes have the same env variables, which are:

[test_gpu/1] environ({'CONDA_SHLVL': '2', 'LD_LIBRARY_PATH': '/home/mxsmith/.conda/envs/model38/lib', 'LS_COLORS': 'no=00:di=34;01:tw=34;01:ow=34;01:fi=00:ln=00:pi=00:so=00:bd=00:cd=00:or=00:mi=00:ex=00:*.sh=31:*.exe=31:*.bat=31', 'CONDA_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/conda', 'SRUN_DEBUG': '3', 'SLURM_STEP_ID': '0', 'SLURM_STEP_GPUS': '0', 'SLURM_NODEID': '0', 'SLURM_TASK_PID': '827752', 'HTTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SSH_CONNECTION': '141.211.21.82 51992 141.211.192.38 22', 'SLURM_PRIO_PROCESS': '0', 'SLURM_CPU_BIND_VERBOSE': 'quiet', 'IEX_TOKEN': 'pk_6cd11420a64b4d5a856ac31281418f38', 'LANG': 'en_US.UTF-8', 'SLURM_SUBMIT_DIR': '/home/mxsmith', 'HISTCONTROL': 'ignoreboth:erasedups', 'HOSTNAME': 'gl1520.arc-ts.umich.edu', 'OLDPWD': '/home/mxsmith', 'SLURM_STEPID': '0', 'SLURM_SRUN_COMM_HOST': '141.211.192.38', 'EDITOR': 'emacs', 'SLURM_DISTRIBUTION': 'cyclic', 'ROCR_VISIBLE_DEVICES': '0', 'CONDA_PREFIX': '/home/mxsmith/.conda/envs/model38', 'SQUEUE_FORMAT': '%.18i %.9P %40j %.8u %.2t %.10M %.20R', 'SLURM_PROCID': '0', 'SLURM_JOB_GID': '99464869', 'SLURM_CPU_BIND': 'quiet,mask_cpu:0x00000001', 'SLURMD_NODENAME': 'gl1520', 'GIT_EDITOR': 'emacs', 'SLURM_TASKS_PER_NODE': '1', 'S_COLORS': 'auto', '_CE_M': '', 'XLA_PYTHON_CLIENT_PREALLOCATE': 'false', 'TF2_BEHAVIOR': '1', 'XDG_SESSION_ID': '11093', 'SLURM_NNODES': '1', 'USER': 'mxsmith', 'SLURM_LAUNCH_NODE_IPADDR': '141.211.192.38', 'CONDA_PREFIX_1': '/sw/pkgs/arc/python3.9-anaconda/2021.11', 'SLURM_STEP_TASKS_PER_NODE': '1', 'MATPLOTLIBRC': '/home/mxsmith/profile/matplotlib', 'FTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'PWD': '/home/mxsmith/projects', 'SSH_ASKPASS': '/usr/libexec/openssh/gnome-ssh-askpass', 'SLURM_JOB_NODELIST': 'gl1520', 'HOME': '/home/mxsmith', 'SLURM_CLUSTER_NAME': 'greatlakes', 'CONDA_PYTHON_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/python', 'SLURM_NODELIST': 'gl1520', 'SLURM_GPUS_ON_NODE': '1', 'SSH_CLIENT': '141.211.21.82 51992 22', 
'LMOD_VERSION': '8.6.14', 'SLURM_NTASKS': '1', 'TMUX': '/tmp/tmux-99464869/default,827956,0', 'rsync_proxy': 'proxy.arc-ts.umich.edu:3128', 'SLURM_UMASK': '0002', 'https_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'KRB5CCNAME': 'FILE:/tmp/krb5cc_99464869_TngfOC', 'TF_CPP_MIN_LOG_LEVEL': '1', 'SLURM_JOB_CPUS_PER_NODE': '4', 'BASH_ENV': '/sw/lmod/lmod/init/bash', 'XDG_DATA_DIRS': '/home/mxsmith/.local/share/flatpak/exports/share:/var/lib/flatpak/exports/share:/usr/local/share:/usr/share', 'AUTOJUMP_ERROR_PATH': '/home/mxsmith/.local/share/autojump/errors.log', 'SLURM_TOPOLOGY_ADDR': 'gl1520', 'http_proxy': 'http://proxy.arc-ts.umich.edu:3128/', '_CE_CONDA': '', 'SLURM_WORKING_CLUSTER': 'greatlakes:glctld:6817:9472:109', 'SLURM_STEP_NODELIST': 'gl1520', 'SLURM_JOB_NAME': 'bash', 'SLURM_SRUN_COMM_PORT': '60207', 'TMPDIR': '/tmp', 'LMOD_sys': 'Linux', 'SLURM_JOBID': '45908601', 'JAX_PLATFORM_NAME': 'gpu', 'SLURM_CONF': '/var/spool/slurmd.spool/conf-cache/slurm.conf', 'LMOD_AVAIL_STYLE': 'grouped', 'no_proxy': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'LMOD_ROOT': '/sw/lmod', 'SLURM_JOB_QOS': 'normal', 'SLURM_TOPOLOGY_ADDR_PATTERN': 'node', 'CONDA_PROMPT_MODIFIER': '(model38) ', 'SSH_TTY': '/dev/pts/80', 'NO_PROXY': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'MAIL': '/var/spool/mail/mxsmith', 'HTTPS_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SLURM_CPUS_ON_NODE': '4', 'XLA_PYTHON_CLIENT_MEM_FRACTION': '.2', 'VISUAL': 'emacs', 'SLURM_JOB_NUM_NODES': '1', 'AUTOJUMP_SOURCED': '1', 'SHELL': '/bin/bash', 'TERM': 'xterm-256color', 'SLURM_JOB_UID': '99464869', 'SLURM_JOB_PARTITION': 'spgpu', 'SLURM_PTY_WIN_ROW': '49', 'SLURM_CPU_BIND_LIST': '0x00000001', 'SLURM_JOB_USER': 'mxsmith', 'CUDA_VISIBLE_DEVICES': '0', 'SLURM_PTY_WIN_COL': '97', 'TMUX_PANE': '%3', 'SLURM_NPROCS': '1', 'SHLVL': '3', 'SLURM_SUBMIT_HOST': 'gl-login1.arc-ts.umich.edu', 'SLURM_JOB_ACCOUNT': 'wellman0', 'MANPATH': 
'/sw/lmod/lmod/share/man:/usr/local/share/man:/usr/share/man:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man:/opt/slurm/share/man/:/opt/TurboVNC/man/:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man', 'SLURM_STEP_LAUNCHER_PORT': '60207', 'MODULEPATH': '/sw/lmod/lmod/modulefiles/Core:/sw/modules/Core:/sw/modules/Collections', 'SLURM_PTY_PORT': '60206', 'SLURM_GTIDS': '0', 'LOGNAME': 'mxsmith', 'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/99464869/bus', 'XDG_RUNTIME_DIR': '/run/user/99464869', 'MODULEPATH_ROOT': '/sw/modules', 'LMOD_PACKAGE_PATH': '/sw/lmod', 'PATH': '/home/mxsmith/software/bin:/home/mxsmith/software/bin:/home/mxsmith/.conda/envs/model38/bin:/sw/pkgs/arc/python3.9-anaconda/2021.11/condabin:/home/mxsmith/software/bin:/opt/TurboVNC/bin:/opt/slurm/bin:/opt/slurm/sbin:/sw/pkgs/arc/usertools/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/usr/lpp/mmfs/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin', 'SLURM_JOB_ID': '45908601', 'SLURM_CPU_BIND_TYPE': 'mask_cpu:', 'SLURM_STEP_NUM_TASKS': '1', 'MODULESHOME': '/sw/lmod/lmod', 'CONDA_DEFAULT_ENV': 'model38', 'LMOD_SETTARG_FULL_SUPPORT': 'no', 'HISTSIZE': '1000', 'LMOD_PKG': '/sw/lmod/lmod', 'CLUSTER_NAME': 'greatlakes', 'SLURM_STEP_NUM_NODES': '1', 'ftp_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'RSYNC_PROXY': 'proxy.arc-ts.umich.edu:3128', 'LMOD_CMD': '/sw/lmod/lmod/libexec/lmod', 'SLURM_LOCALID': '0', 'GPU_DEVICE_ORDINAL': '0', 'LESSOPEN': '||/usr/bin/lesspipe.sh %s', 'LMOD_DIR': '/sw/lmod/lmod/libexec', 'BASH_FUNC_module%%': '() {  local __lmod_my_status;\n local __lmod_sh_dbg;\n if [ -z "${LMOD_SH_DBG_ON+x}" ]; then\n case "$-" in \n *v*x*)\n __lmod_sh_dbg=\'vx\'\n ;;\n *v*)\n __lmod_sh_dbg=\'v\'\n ;;\n *x*)\n 
__lmod_sh_dbg=\'x\'\n ;;\n esac;\n fi;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n set +$__lmod_sh_dbg;\n echo "Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for Lmod\'s output" 1>&2;\n fi;\n eval "$($LMOD_CMD bash "$@")" && eval $(${LMOD_SETTARG_CMD:-:} -s sh);\n __lmod_my_status=$?;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n echo "Shell debugging restarted" 1>&2;\n set -$__lmod_sh_dbg;\n fi;\n return $__lmod_my_status\n}', 'BASH_FUNC_ml%%': '() {  eval "$($LMOD_DIR/ml_cmd "$@")"\n}'})
[test_gpu/0] environ({'CONDA_SHLVL': '2', 'LD_LIBRARY_PATH': '/home/mxsmith/.conda/envs/model38/lib', 'LS_COLORS': 'no=00:di=34;01:tw=34;01:ow=34;01:fi=00:ln=00:pi=00:so=00:bd=00:cd=00:or=00:mi=00:ex=00:*.sh=31:*.exe=31:*.bat=31', 'CONDA_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/conda', 'SRUN_DEBUG': '3', 'SLURM_STEP_ID': '0', 'SLURM_STEP_GPUS': '0', 'SLURM_NODEID': '0', 'SLURM_TASK_PID': '827752', 'HTTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SSH_CONNECTION': '141.211.21.82 51992 141.211.192.38 22', 'SLURM_PRIO_PROCESS': '0', 'SLURM_CPU_BIND_VERBOSE': 'quiet', 'IEX_TOKEN': 'pk_6cd11420a64b4d5a856ac31281418f38', 'LANG': 'en_US.UTF-8', 'SLURM_SUBMIT_DIR': '/home/mxsmith', 'HISTCONTROL': 'ignoreboth:erasedups', 'HOSTNAME': 'gl1520.arc-ts.umich.edu', 'OLDPWD': '/home/mxsmith', 'SLURM_STEPID': '0', 'SLURM_SRUN_COMM_HOST': '141.211.192.38', 'EDITOR': 'emacs', 'SLURM_DISTRIBUTION': 'cyclic', 'ROCR_VISIBLE_DEVICES': '0', 'CONDA_PREFIX': '/home/mxsmith/.conda/envs/model38', 'SQUEUE_FORMAT': '%.18i %.9P %40j %.8u %.2t %.10M %.20R', 'SLURM_PROCID': '0', 'SLURM_JOB_GID': '99464869', 'SLURM_CPU_BIND': 'quiet,mask_cpu:0x00000001', 'SLURMD_NODENAME': 'gl1520', 'GIT_EDITOR': 'emacs', 'SLURM_TASKS_PER_NODE': '1', 'S_COLORS': 'auto', '_CE_M': '', 'XLA_PYTHON_CLIENT_PREALLOCATE': 'false', 'TF2_BEHAVIOR': '1', 'XDG_SESSION_ID': '11093', 'SLURM_NNODES': '1', 'USER': 'mxsmith', 'SLURM_LAUNCH_NODE_IPADDR': '141.211.192.38', 'CONDA_PREFIX_1': '/sw/pkgs/arc/python3.9-anaconda/2021.11', 'SLURM_STEP_TASKS_PER_NODE': '1', 'MATPLOTLIBRC': '/home/mxsmith/profile/matplotlib', 'FTP_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'PWD': '/home/mxsmith/projects', 'SSH_ASKPASS': '/usr/libexec/openssh/gnome-ssh-askpass', 'SLURM_JOB_NODELIST': 'gl1520', 'HOME': '/home/mxsmith', 'SLURM_CLUSTER_NAME': 'greatlakes', 'CONDA_PYTHON_EXE': '/sw/pkgs/arc/python3.9-anaconda/2021.11/bin/python', 'SLURM_NODELIST': 'gl1520', 'SLURM_GPUS_ON_NODE': '1', 'SSH_CLIENT': '141.211.21.82 51992 22', 
'LMOD_VERSION': '8.6.14', 'SLURM_NTASKS': '1', 'TMUX': '/tmp/tmux-99464869/default,827956,0', 'rsync_proxy': 'proxy.arc-ts.umich.edu:3128', 'SLURM_UMASK': '0002', 'https_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'KRB5CCNAME': 'FILE:/tmp/krb5cc_99464869_TngfOC', 'TF_CPP_MIN_LOG_LEVEL': '1', 'SLURM_JOB_CPUS_PER_NODE': '4', 'BASH_ENV': '/sw/lmod/lmod/init/bash', 'XDG_DATA_DIRS': '/home/mxsmith/.local/share/flatpak/exports/share:/var/lib/flatpak/exports/share:/usr/local/share:/usr/share', 'AUTOJUMP_ERROR_PATH': '/home/mxsmith/.local/share/autojump/errors.log', 'SLURM_TOPOLOGY_ADDR': 'gl1520', 'http_proxy': 'http://proxy.arc-ts.umich.edu:3128/', '_CE_CONDA': '', 'SLURM_WORKING_CLUSTER': 'greatlakes:glctld:6817:9472:109', 'SLURM_STEP_NODELIST': 'gl1520', 'SLURM_JOB_NAME': 'bash', 'SLURM_SRUN_COMM_PORT': '60207', 'TMPDIR': '/tmp', 'LMOD_sys': 'Linux', 'SLURM_JOBID': '45908601', 'JAX_PLATFORM_NAME': 'gpu', 'SLURM_CONF': '/var/spool/slurmd.spool/conf-cache/slurm.conf', 'LMOD_AVAIL_STYLE': 'grouped', 'no_proxy': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'LMOD_ROOT': '/sw/lmod', 'SLURM_JOB_QOS': 'normal', 'SLURM_TOPOLOGY_ADDR_PATTERN': 'node', 'CONDA_PROMPT_MODIFIER': '(model38) ', 'SSH_TTY': '/dev/pts/80', 'NO_PROXY': 'localhost,127.0.0.1,.localdomain,.umich.edu', 'MAIL': '/var/spool/mail/mxsmith', 'HTTPS_PROXY': 'http://proxy.arc-ts.umich.edu:3128/', 'SLURM_CPUS_ON_NODE': '4', 'XLA_PYTHON_CLIENT_MEM_FRACTION': '.2', 'VISUAL': 'emacs', 'SLURM_JOB_NUM_NODES': '1', 'AUTOJUMP_SOURCED': '1', 'SHELL': '/bin/bash', 'TERM': 'xterm-256color', 'SLURM_JOB_UID': '99464869', 'SLURM_JOB_PARTITION': 'spgpu', 'SLURM_PTY_WIN_ROW': '49', 'SLURM_CPU_BIND_LIST': '0x00000001', 'SLURM_JOB_USER': 'mxsmith', 'CUDA_VISIBLE_DEVICES': '0', 'SLURM_PTY_WIN_COL': '97', 'TMUX_PANE': '%3', 'SLURM_NPROCS': '1', 'SHLVL': '3', 'SLURM_SUBMIT_HOST': 'gl-login1.arc-ts.umich.edu', 'SLURM_JOB_ACCOUNT': 'wellman0', 'MANPATH': 
'/sw/lmod/lmod/share/man:/usr/local/share/man:/usr/share/man:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man:/opt/slurm/share/man/:/opt/TurboVNC/man/:/opt/ddn/ime/share/man:/opt/ddn/ime/share/man', 'SLURM_STEP_LAUNCHER_PORT': '60207', 'MODULEPATH': '/sw/lmod/lmod/modulefiles/Core:/sw/modules/Core:/sw/modules/Collections', 'SLURM_PTY_PORT': '60206', 'SLURM_GTIDS': '0', 'LOGNAME': 'mxsmith', 'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/99464869/bus', 'XDG_RUNTIME_DIR': '/run/user/99464869', 'MODULEPATH_ROOT': '/sw/modules', 'LMOD_PACKAGE_PATH': '/sw/lmod', 'PATH': '/home/mxsmith/software/bin:/home/mxsmith/software/bin:/home/mxsmith/.conda/envs/model38/bin:/sw/pkgs/arc/python3.9-anaconda/2021.11/condabin:/home/mxsmith/software/bin:/opt/TurboVNC/bin:/opt/slurm/bin:/opt/slurm/sbin:/sw/pkgs/arc/usertools/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/usr/lpp/mmfs/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/opt/ddn/ime/bin:/home/mxsmith/anaconda/bin:/home/mxsmith/software/gambit-15.1.1:/home/mxsmith/.local/bin:/home/mxsmith/bin', 'SLURM_JOB_ID': '45908601', 'SLURM_CPU_BIND_TYPE': 'mask_cpu:', 'SLURM_STEP_NUM_TASKS': '1', 'MODULESHOME': '/sw/lmod/lmod', 'CONDA_DEFAULT_ENV': 'model38', 'LMOD_SETTARG_FULL_SUPPORT': 'no', 'HISTSIZE': '1000', 'LMOD_PKG': '/sw/lmod/lmod', 'CLUSTER_NAME': 'greatlakes', 'SLURM_STEP_NUM_NODES': '1', 'ftp_proxy': 'http://proxy.arc-ts.umich.edu:3128/', 'RSYNC_PROXY': 'proxy.arc-ts.umich.edu:3128', 'LMOD_CMD': '/sw/lmod/lmod/libexec/lmod', 'SLURM_LOCALID': '0', 'GPU_DEVICE_ORDINAL': '0', 'LESSOPEN': '||/usr/bin/lesspipe.sh %s', 'LMOD_DIR': '/sw/lmod/lmod/libexec', 'BASH_FUNC_module%%': '() {  local __lmod_my_status;\n local __lmod_sh_dbg;\n if [ -z "${LMOD_SH_DBG_ON+x}" ]; then\n case "$-" in \n *v*x*)\n __lmod_sh_dbg=\'vx\'\n ;;\n *v*)\n __lmod_sh_dbg=\'v\'\n ;;\n *x*)\n 
__lmod_sh_dbg=\'x\'\n ;;\n esac;\n fi;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n set +$__lmod_sh_dbg;\n echo "Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for Lmod\'s output" 1>&2;\n fi;\n eval "$($LMOD_CMD bash "$@")" && eval $(${LMOD_SETTARG_CMD:-:} -s sh);\n __lmod_my_status=$?;\n if [ -n "${__lmod_sh_dbg:-}" ]; then\n echo "Shell debugging restarted" 1>&2;\n set -$__lmod_sh_dbg;\n fi;\n return $__lmod_my_status\n}', 'BASH_FUNC_ml%%': '() {  eval "$($LMOD_DIR/ml_cmd "$@")"\n}'})
tensorflow==2.8.3
jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1671027717961/work
jaxlib==0.4.1+cuda11.cudnn86
dm-launchpad==0.5.2p

Additional debugging

If you further print out:

from jax._src.config import flags

FLAGS = flags.FLAGS
print(FLAGS.jax_cuda_visible_devices)

from jax._src import distributed

print(distributed.global_state.client)
print(distributed.global_state.service)
print(distributed.global_state.process_id)

Both have the same settings:

[test_gpu/1] all
[test_gpu/1] None
[test_gpu/1] None
[test_gpu/1] 0
[test_gpu/0] all
[test_gpu/0] None
[test_gpu/0] None
[test_gpu/0] 0

As far as I can tell, all of the system's settings are the same during the hand-off to XLA.

If you add this to the nodes, you can see that the cuda backend is missing from one of the processes:

from jax._src.lib import xla_bridge
print("Backends: ", xla_bridge.backends())
[test_gpu/0] Backends:  {'interpreter': <jaxlib.xla_extension.Client object at 0x14b25b27ac30>, 'cpu': <jaxlib.xla_extension.Client object at 0x14b24d95b1b0>, 'cuda': <jaxlib.xla_extension.Client object at 0x14b24c2f7570>}
[test_gpu/1] Backends:  {'interpreter': <jaxlib.xla_extension.Client object at 0x14b3e2dc0f30>, 'cpu': <jaxlib.xla_extension.Client object at 0x14b3e2dc09f0>}

If you also add:

from jax.config import config
print(config.jax_platforms)

They're both None, prompting all of the backend factories to be run.

[test_gpu/1] {'interpreter': (<function make_interpreter_client at 0x145e83579b80>, -100), 'cpu': (functools.partial(<function make_cpu_client at 0x145e8359c790>, use_tfrt=True), 0), 'tpu_driver': (<function _make_tpu_driver_client at 0x145e728bfdc0>, 100), 'cuda': (functools.partial(<function make_gpu_client at 0x145e728bff70>, platform_name='cuda', visible_devices_flag='jax_cuda_visible_devices'), 200), 'rocm': (functools.partial(<function make_gpu_client at 0x145e728bff70>, platform_name='rocm', visible_devices_flag='jax_rocm_visible_devices'), 200), 'tpu': (functools.partial(<function tpu_client_timer_callback at 0x145e728bfe50>, timer_secs=60.0), 300), 'plugin': (<function make_plugin_device_client at 0x145e8359caf0>, 400)}
[test_gpu/0] {'interpreter': (<function make_interpreter_client at 0x150b925e0b80>, -100), 'cpu': (functools.partial(<function make_cpu_client at 0x150b92602790>, use_tfrt=True), 0), 'tpu_driver': (<function _make_tpu_driver_client at 0x150b81925dc0>, 100), 'cuda': (functools.partial(<function make_gpu_client at 0x150b81925f70>, platform_name='cuda', visible_devices_flag='jax_cuda_visible_devices'), 200), 'rocm': (functools.partial(<function make_gpu_client at 0x150b81925f70>, platform_name='rocm', visible_devices_flag='jax_rocm_visible_devices'), 200), 'tpu': (functools.partial(<function tpu_client_timer_callback at 0x150b81925e50>, timer_secs=60.0), 300), 'plugin': (<function make_plugin_device_client at 0x150b92602af0>, 400)}

What jax/jaxlib version are you using?

jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1671027717961/work 
jaxlib==0.4.1+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Red Hat Enterprise Linux 8.4 (Ootpa)

NVIDIA GPU info

NVIDIA-SMI 510.73.08 Driver Version: 510.73.08 CUDA Version: 11.6
NVIDIA A40

nouiz commented 1 year ago

My guess is that this is an issue with the SLURM config or how you call it. What is the SLURM command that you use?

From memory, nvidia-smi shows all the visible GPUs. That doesn't mean you have access to them.

nouiz commented 1 year ago

Did you see this documentation: https://jax.readthedocs.io/en/latest/multi_process.html ?

maxosmith commented 1 year ago

My SLURM command was: srun --pty --gres=gpu:1 --cpus-per-gpu=4 --mem-per-cpu=10g --time=0-01:00 /bin/bash

I did miss that document, I'll give it a pass now, cheers!

maxosmith commented 1 year ago

Ah, that document is about pmap; I'm not trying to distribute the workload in that way. I'm trying to have a learner node use the GPU, while the actor nodes do not need it.

nouiz commented 1 year ago

If you request only 1 node with 1 GPU and many CPUs, you can create a bash script that dispatches like this:

CUDA_VISIBLE_DEVICES= python actor.py & # As many times as needed
python learner.py

The first line hides the GPU from the actor process. This works for all software, not just JAX.
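The same dispatch pattern can be sketched in Python. This is a minimal sketch using only the standard library, not what anyone in this thread ran: `CHILD_CODE`, `child_env`, `launch`, and `run_demo` are hypothetical names, and the inline child command stands in for actor.py and learner.py by printing the `CUDA_VISIBLE_DEVICES` value the child sees.

```python
import os
import subprocess
import sys

# Stand-in for actor.py / learner.py: report what the child process sees.
CHILD_CODE = "import os; print(os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))"


def child_env(hide_gpu):
    """Return a copy of the current environment, optionally hiding all CUDA devices."""
    env = dict(os.environ)
    if hide_gpu:
        # An empty value hides every CUDA device from the child process.
        env["CUDA_VISIBLE_DEVICES"] = ""
    return env


def launch(hide_gpu):
    """Start one child process without waiting for it and return its handle."""
    return subprocess.Popen(
        [sys.executable, "-c", CHILD_CODE],
        env=child_env(hide_gpu),
        stdout=subprocess.PIPE,
        text=True,
    )


def run_demo(num_actors=2):
    """Launch GPU-less actors in the background, then one learner; collect outputs."""
    actors = [launch(hide_gpu=True) for _ in range(num_actors)]
    learner = launch(hide_gpu=False)  # inherits the parent's device visibility
    actor_out = [p.communicate()[0].strip() for p in actors]
    learner_out = learner.communicate()[0].strip()
    return actor_out, learner_out


if __name__ == "__main__":
    print(run_demo())
```

Each actor should report an empty device list, while the learner reports whatever the parent shell exported.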

maxosmith commented 1 year ago

Thanks for the reply.

In the example script I provided, I'm dispatching two PythonProcess configurations that use these environment variable settings:

            "test_cpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "",
                    "JAX_PLATFORM_NAME": "cpu",
                }
            ),
            "test_gpu": PythonProcess(
                env={
                    "CUDA_VISIBLE_DEVICES": "0",
                    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".2",
                    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
                    "JAX_PLATFORM_NAME": "gpu",
                }
            ),

However, if I try and spawn two processes that can both see device 0, only the first process is seeing the device. Does that make sense?

nouiz commented 1 year ago

| However, if I try and spawn two processes that can both see device 0, only the first process is seeing the device. Does that make sense?

I think the issue isn't with JAX. It is probably related to your scheduler. I suppose you spawn two test_gpu processes and make sure they are on the same node? Can you print this in both processes? print(os.environ.get('CUDA_VISIBLE_DEVICES'))

I do not know Launchpad, so I can't help much here. Why do you use it? Can you give me the full output of nvidia-smi? It is possible that the GPU is configured to be usable by only 1 process at a time. Clusters are sometimes configured that way, which could explain your issue.
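One way to check for that kind of configuration without reading the full nvidia-smi report is to query the compute mode directly. The sketch below is hedged: it assumes nvidia-smi is on the PATH and that its `compute_mode` query field reports strings such as "Default" or "Exclusive_Process"; the function names are hypothetical, and it returns None rather than guessing when the tool is unavailable.

```python
import shutil
import subprocess


def parse_compute_modes(csv_text):
    """Turn nvidia-smi CSV output (one line per GPU) into a list of mode strings."""
    return [line.strip() for line in csv_text.splitlines() if line.strip()]


def query_compute_modes():
    """Return the compute mode of each GPU, or None if nvidia-smi cannot be used."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_mode", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return None
    return parse_compute_modes(result.stdout)


def allows_multiple_processes(mode):
    """Assumption: only the 'Default' mode lets several processes share one GPU."""
    return mode.strip().lower() == "default"


if __name__ == "__main__":
    modes = query_compute_modes()
    if modes is None:
        print("nvidia-smi not available")
    else:
        for index, mode in enumerate(modes):
            print(index, mode, "shared" if allows_multiple_processes(mode) else "restricted")
```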

maxosmith commented 1 year ago
$ python launchpad_gpu_test.py

Local
0

Launchpad
[test_cpu/0] I1220 18:34:44.112143 23450219082752 courier_utils.py:120] Binding: run
[test_gpu/0] I1220 18:34:44.113588 22711702426624 courier_utils.py:120] Binding: run
[test_gpu/1] I1220 18:34:44.114040 23041964385280 courier_utils.py:120] Binding: run
[test_cpu/0] I1220 18:34:44.114264 23450219082752 courier_utils.py:120] Binding: test
[test_gpu/0] I1220 18:34:44.114415 22711702426624 courier_utils.py:120] Binding: test
[test_gpu/1] I1220 18:34:44.114561 23041964385280 courier_utils.py:120] Binding: test
[test_gpu/1] 0
[test_gpu/0] 0
[test_cpu/0]

Works as expected. GPU nodes print "0", and the local driver program also prints "0". The restricted node does not see the device.

However, if I add print("Devices: ", jax.devices()) right after all of the print(os.environ.get('CUDA_VISIBLE_DEVICES')), all of the processes correctly print their environment variables but all use CPUs.

If I prevent the driver program from binding to the GPU, the processes' outputs are:

...
[test_gpu/1] 0
[test_gpu/1] Devices:  [CpuDevice(id=0)]
[test_gpu/0] 0
[test_gpu/0] Devices:  [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
[test_cpu/0]
[test_cpu/0] Devices:  [CpuDevice(id=0)]

I'll submit a ticket to the supercomputer team to see if the suggested limitation is in place.

maxosmith commented 1 year ago

Thanks for being so interactive, Frédéric. I was surprised to see your name pop up here, but I remember you from our rare interactions back at Pavillon André-Aisenstadt in, what, 2016? :)

maxosmith commented 1 year ago

I've heard back from the supercomputer team and the devices are indeed set to "exclusive process" mode, which is undoubtedly the issue I'm encountering. Thanks again for all the help.
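Given that restriction, one possible arrangement is to let exactly one process group see the GPU and hide it from every other group. The sketch below is an illustration under stated assumptions, not a confirmed fix: `build_group_envs` and the group names "learner" and "actor" are hypothetical, the owning group is assumed to contain a single process, and the driver program is assumed to hide the GPU from itself as well. The environment variable names are the ones used earlier in this thread.

```python
def build_group_envs(groups, gpu_owner, device="0"):
    """Map each group name to an env dict in which only `gpu_owner` sees the GPU."""
    if gpu_owner not in groups:
        raise ValueError(f"unknown GPU owner: {gpu_owner!r}")
    envs = {}
    for name in groups:
        if name == gpu_owner:
            # The single group allowed to create a CUDA context on the device.
            envs[name] = {
                "CUDA_VISIBLE_DEVICES": device,
                "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
                "JAX_PLATFORM_NAME": "gpu",
            }
        else:
            # Every other group is pinned to the CPU and never touches the GPU.
            envs[name] = {
                "CUDA_VISIBLE_DEVICES": "",
                "JAX_PLATFORM_NAME": "cpu",
            }
    return envs


if __name__ == "__main__":
    for group, env in build_group_envs(["learner", "actor"], gpu_owner="learner").items():
        print(group, env)
```

Each resulting dictionary could be passed as the env argument of a PythonProcess entry, in the same shape as the local_resources mapping in the original script.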