google-research / t5x

Apache License 2.0

How to run t5x on multi-node GPUs? #832

Open lintangsutawika opened 2 years ago

lintangsutawika commented 2 years ago

I'm exploring how to use t5x in a multi-node GPU setting. I'm using SLURM with a singularity container to execute the training script.

#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --job-name=lintang-t5x-multinode
#SBATCH --nodes=2
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
#SBATCH --output=logs/%x_%j.out
#SBATCH --cpus-per-task=32
#SBATCH --exclusive
#SBATCH --requeue
#SBATCH --wait-all-nodes=1
#SBATCH --comment=ProjectName

# Cache Directories
export SINGULARITY_CACHEDIR=...
export BASE_DIR="..."

export PROJECT_DIR=${BASE_DIR}"..."
export MODEL_DIR="..."
export TFDS_DATA_DIR="..."

# directory where the T5X repo is cloned.
export T5X_DIR=${BASE_DIR}"..."
export PYTHONPATH=${PROJECT_DIR}

# export TF_XLA_FLAGS="--tf_xla_auto_jit=2"
export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1" # Hacky and don't want

singularity exec \
    --nv --bind /fsx:/fsx t5x-env.sif \
    python ${T5X_DIR}/t5x/train.py \
        --gin_search_paths=${PROJECT_DIR} \
        --gin_file="config-base.gin" \
        --gin.MODEL_DIR=\"${MODEL_DIR}\" \
        --gin.USE_CACHED_TASKS=False \
        --alsologtostderr \
        --multiprocess_gpu \
        --coordinator_address="${SLURM_LAUNCH_NODE_IPADDR}:29500" \
        --process_count "${SLURM_NPROCS}" \
        --process_index "${SLURM_PROCID}"

But this doesn't seem to work.

2022-10-12 15:03:54.406737: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-12 15:03:55.491066: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/hcoll/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ompi/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/nccl_rdma_sharp_plugin/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/sharp/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib/ucx:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/11.7/nccl/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/compilers/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/cuda/11.7/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/math_libs/11.7/lib64:::/.singularity.d/libs
2022-10-12 15:03:55.492223: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/hcoll/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ompi/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/nccl_rdma_sharp_plugin/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/sharp/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib/ucx:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/11.7/nccl/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/compilers/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/cuda/11.7/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/math_libs/11.7/lib64:::/.singularity.d/libs
2022-10-12 15:03:55.492252: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2022-10-12 15:03:59.741770: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/hcoll/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ompi/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/nccl_rdma_sharp_plugin/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/sharp/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/hpcx/hpcx-2.11/ucx/mt/lib/ucx:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/comm_libs/11.7/nccl/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/compilers/lib:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/cuda/11.7/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.7/math_libs/11.7/lib64:::/.singularity.d/libs
2022-10-12 15:03:59.743346: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I1012 15:03:59.743987 140385279873024 train.py:725] Initializing distributed system for multi-host GPU:
  coordinator_address: :29500
  process_count: 2
  process_index: 0
I1012 15:03:59.744206 140385279873024 distributed.py:58] JAX distributed initialized with visible devices: 0
I1012 15:03:59.744523 140385279873024 distributed.py:67] Starting JAX distributed service on :29500
E1012 15:03:59.747199977    8449 server_chttp2.cc:40]        {"created":"@1665587039.747188280","description":"Name or service not known","errno":-2,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/resolve_address_posix.cc","file_line":108,"os_error":"Name or service not known","syscall":"getaddrinfo","target_address":":29500"}
Fatal Python error: Segmentation fault

Thread 0x00007fadfebd1000 (most recent call first):
  File "/usr/local/lib/python3.8/site-packages/jax/_src/distributed.py", line 68 in initialize
  File "/usr/local/lib/python3.8/site-packages/jax/_src/distributed.py", line 159 in initialize
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 730 in _main
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 710 in main
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 254 in _run_main
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 308 in run
  File "/fsx/lintangsutawika/t5x/t5x/gin_utils.py", line 107 in run
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 750 in <module>
/var/spool/slurmd/job08329/slurm_script: line 53:  8420 Segmentation fault      singularity exec --nv --bind /fsx:/fsx t5x-env.sif python ${T5X_DIR}/t5x/train.py --gin_search_paths=${PROJECT_DIR} --gin_file="config-base.gin" --gin.MODEL_DIR=\"${MODEL_DIR}\" --gin.USE_CACHED_TASKS=False --alsologtostderr --multiprocess_gpu --coordinator_address="${SLURM_LAUNCH_NODE_IPADDR}:29500" --process_count "${SLURM_NPROCS}" --process_index "${SLURM_PROCID}"

Another method I tried was to launch two processes with a hard-coded process_index (I did this in an interactive shell):

singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
    python ${T5X_DIR}/t5x/train.py \
        --gin_search_paths=${PROJECT_DIR} \
        --gin_file="config-base.gin" \
        --gin.MODEL_DIR=\"${MODEL_DIR}\" \
        --gin.USE_CACHED_TASKS=False \
        --alsologtostderr \
        --multiprocess_gpu \
        --coordinator_address="${SLURM_LAUNCH_NODE_IPADDR}:29500" \
        --process_count "${SLURM_NPROCS}" \
        --process_index 0
singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
    python ${T5X_DIR}/t5x/train.py \
        --gin_search_paths=${PROJECT_DIR} \
        --gin_file="config-base.gin" \
        --gin.MODEL_DIR=\"${MODEL_DIR}\" \
        --gin.USE_CACHED_TASKS=False \
        --alsologtostderr \
        --multiprocess_gpu \
        --coordinator_address="${SLURM_LAUNCH_NODE_IPADDR}:29500" \
        --process_count "${SLURM_NPROCS}" \
        --process_index 1

process_index 1 seems to be working as intended:

I1012 16:17:41.957645 140701990588416 train.py:725] Initializing distributed system for multi-host GPU:
  coordinator_address: 172.31.37.44:29500
  process_count: 2
  process_index: 1
I1012 16:17:41.957846 140701990588416 distributed.py:58] JAX distributed initialized with visible devices: 0
I1012 16:17:41.986860 140701990588416 distributed.py:78] Connecting to JAX distributed service on 172.31.37.44:29500

but process_index 0 fails.

I1012 16:18:04.230459 140621940224000 train.py:725] Initializing distributed system for multi-host GPU:
  coordinator_address: 172.31.37.44:29500
  process_count: 2
  process_index: 0
I1012 16:18:04.230640 140621940224000 distributed.py:58] JAX distributed initialized with visible devices: 0
I1012 16:18:04.230890 140621940224000 distributed.py:67] Starting JAX distributed service on 172.31.37.44:29500
E1012 16:18:04.231958535    6069 server_chttp2.cc:40]        {"created":"@1665591484.231935845","description":"No address added out of total 1 resolved","file":"external/com_github_grpc_grpc/src/core/ext/transport/chttp2/server/chttp2_server.cc","file_line":395,"referenced_errors":[{"created":"@1665591484.231933760","description":"Unable to configure socket","fd":30,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":215,"referenced_errors":[{"created":"@1665591484.231931381","description":"Cannot assign requested address","errno":99,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":189,"os_error":"Cannot assign requested address","syscall":"bind"}]}]}
Fatal Python error: Segmentation fault

Thread 0x00007fe518cb7000 (most recent call first):
  File "/usr/local/lib/python3.8/site-packages/jax/_src/distributed.py", line 68 in initialize
  File "/usr/local/lib/python3.8/site-packages/jax/_src/distributed.py", line 159 in initialize
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 730 in _main
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 710 in main
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 254 in _run_main
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 308 in run
  File "/fsx/lintangsutawika/t5x/t5x/gin_utils.py", line 107 in run
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 750 in <module>
Segmentation fault

sudhakarsingh27 commented 2 years ago

Looks like the coordinator address isn't set properly.

coordinator_address: :29500
process_count: 2
process_index: 0

Could you recheck the value of ${SLURM_LAUNCH_NODE_IPADDR} in your environment? (It's being picked up correctly for process_index=1 but not for process_index=0.)
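A minimal Python sketch of that check, assuming the address is assembled from SLURM_LAUNCH_NODE_IPADDR and port 29500 as in the script above. The helper names `split_coordinator_address` and `coordinator_from_env` are hypothetical and not part of t5x or JAX; the point is only to fail early with a readable message when the host part is empty, as in `:29500`.

```python
import os
from typing import Mapping, Optional, Tuple


def split_coordinator_address(address: str) -> Tuple[str, int]:
    """Split 'host:port' and reject an empty host such as ':29500'."""
    host, sep, port = address.rpartition(":")
    if not sep or not host:
        raise ValueError(f"coordinator address {address!r} has no host part")
    if not port.isdigit():
        raise ValueError(f"coordinator address {address!r} has no numeric port")
    return host, int(port)


def coordinator_from_env(
    port: int = 29500, environ: Optional[Mapping[str, str]] = None
) -> Tuple[str, int]:
    """Build the address from SLURM_LAUNCH_NODE_IPADDR (hypothetical helper)."""
    env = os.environ if environ is None else environ
    # An unset variable expands to an empty host, which is rejected below.
    host = env.get("SLURM_LAUNCH_NODE_IPADDR", "")
    return split_coordinator_address(f"{host}:{port}")
```

Calling `coordinator_from_env()` before starting the training script would raise a `ValueError` in the first failing run above instead of reaching the gRPC resolver.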

lintangsutawika commented 2 years ago

Yeah, looks like SLURM_LAUNCH_NODE_IPADDR isn't the machine's actual address. I tested with "127.0.0.1:29500" as the coordinator_address, and running this no longer stops at a segmentation fault:

singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
    python ${T5X_DIR}/t5x/train.py \
        --gin_search_paths=${PROJECT_DIR} \
        --gin_file="config-base.gin" \
        --gin.MODEL_DIR=\"${MODEL_DIR}\" \
        --gin.USE_CACHED_TASKS=False \
        --alsologtostderr \
        --multiprocess_gpu \
        --coordinator_address="${ADDR}" \
        --process_count "${SLURM_NPROCS}" \
        --process_index 0 \
& \
singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
    python ${T5X_DIR}/t5x/train.py \
        --gin_search_paths=${PROJECT_DIR} \
        --gin_file="config-base.gin" \
        --gin.MODEL_DIR=\"${MODEL_DIR}\" \
        --gin.USE_CACHED_TASKS=False \
        --alsologtostderr \
        --multiprocess_gpu \
        --coordinator_address="${ADDR}" \
        --process_count "${SLURM_NPROCS}" \
        --process_index 1

However, now I see that the script only detects 1 GPU per process? Also, a new error.


/usr/local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:556: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
I1012 18:00:11.249953 140565072195584 partitioning.py:331] global_mesh axis_names: ('data', 'model')
I1012 18:00:11.250089 140565072195584 partitioning.py:332] global_mesh devices: [[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=1)]]
2022-10-12 18:00:12.262847: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 1 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-10-12 18:00:12.272575: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
Traceback (most recent call last):
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 749, in <module>
    gin_utils.run(main)
  File "/fsx/lintangsutawika/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 710, in main
    _main(argv)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 745, in _main
    train_using_gin()
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 251, in train
    train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1371, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1385, in get_dataset_inner
    multihost_assert_equal(
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 566, in multihost_assert_equal
    multihost_utils.assert_equal(input_tree, fail_message)
  File "/usr/local/lib/python3.8/site-packages/jax/experimental/multihost_utils.py", line 169, in assert_equal
    expected = broadcast_one_to_all(in_tree)
  File "/usr/local/lib/python3.8/site-packages/jax/experimental/multihost_utils.py", line 75, in broadcast_one_to_all
    in_tree = jax.device_get(_psum(in_tree))
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
  In call to configurable 'train' (<function train at 0x7f2f3dd08670>)

sudhakarsingh27 commented 2 years ago

Could you set the environment variable NCCL_DEBUG=INFO before the python command and rerun? That would let us see some debug info from NCCL.

sudhakarsingh27 commented 2 years ago

However, now I see that the script only detects 1 GPU per process?

That's because jax.distributed.initialize was recently modified to run 1 GPU per process as the default setting. To let the process see all the GPUs in the node, you could pass an additional argument local_device_ids=list(range(num_gpus)) to jax.distributed.initialize.
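As a rough sketch of that suggestion: the helper below computes which local GPU ids a process would claim, assuming process indices are assigned node by node and that each node's GPUs are split evenly among its processes. `local_device_ids_for`, `initialize_distributed`, `processes_per_node`, and `gpus_per_node` are hypothetical names introduced for illustration; only the keyword arguments passed to `jax.distributed.initialize` come from JAX.

```python
from typing import List


def local_device_ids_for(
    process_index: int, processes_per_node: int, gpus_per_node: int
) -> List[int]:
    """Return the local GPU ids for one process (even split, node-major order)."""
    if processes_per_node <= 0 or gpus_per_node % processes_per_node != 0:
        raise ValueError("gpus_per_node must be divisible by processes_per_node")
    gpus_per_process = gpus_per_node // processes_per_node
    # Position of this process among the processes on its own node.
    local_rank = process_index % processes_per_node
    start = local_rank * gpus_per_process
    return list(range(start, start + gpus_per_process))


def initialize_distributed(
    coordinator_address: str,
    process_count: int,
    process_index: int,
    processes_per_node: int,
    gpus_per_node: int,
) -> None:
    """Call jax.distributed.initialize with an explicit local_device_ids list."""
    import jax  # imported lazily so the helper above works without JAX installed

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=process_count,
        process_id=process_index,
        local_device_ids=local_device_ids_for(
            process_index, processes_per_node, gpus_per_node
        ),
    )
```

With one process per 8-GPU node, every process gets ids 0 through 7; with two processes per node, the first gets 0 through 3 and the second 4 through 7.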

lintangsutawika commented 2 years ago

Looks like it's an out-of-memory issue? I reduced the batch size to 1, but the problem persists.

gpu-st-p4d-24xlarge-25:15647:15647 [0] external/nccl_archive/src/enqueue.cc:128 NCCL WARN Cuda failure 'out of memory'
gpu-st-p4d-24xlarge-25:15647:15647 [0] NCCL INFO Bootstrap : Using eth0:172.31.224.66<0>
gpu-st-p4d-24xlarge-25:15647:15647 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
gpu-st-p4d-24xlarge-25:15647:15647 [0] NCCL INFO cudaDriverVersion 11040
NCCL version 2.13.4+cudaCUDA_MAJOR.CUDA_MINOR

gpu-st-p4d-24xlarge-25:15647:15647 [0] external/nccl_archive/src/init.cc:1075 NCCL WARN Cuda failure 'out of memory'
gpu-st-p4d-24xlarge-25:15647:15647 [0] NCCL INFO external/nccl_archive/src/init.cc:1106 -> 1
2022-10-13 04:16:09.849660: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
Traceback (most recent call last):
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 754, in <module>
    gin_utils.run(main)
  File "/fsx/lintangsutawika/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 710, in main
    _main(argv)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 750, in _main
    train_using_gin()
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 251, in train
    train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1371, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1385, in get_dataset_inner
    multihost_assert_equal(
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 566, in multihost_assert_equal
    multihost_utils.assert_equal(input_tree, fail_message)
  File "/usr/local/lib/python3.8/site-packages/jax/experimental/multihost_utils.py", line 175, in assert_equal
    expected = broadcast_one_to_all(in_tree)
  File "/usr/local/lib/python3.8/site-packages/jax/experimental/multihost_utils.py", line 75, in broadcast_one_to_all
    in_tree = jax.device_get(_psum(in_tree))
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
  In call to configurable 'train' (<function train at 0x7fd683eab9d0>)

sudhakarsingh27 commented 2 years ago

Ack. This looks like a familiar issue. Could you try a few workarounds (WARs) while we figure out a fix:

  1. Add the following code to your t5x/train.py file:
    ...
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")
    ...
  2. If step 1 doesn't work, try reducing the memory preallocated by JAX by setting XLA_PYTHON_CLIENT_MEM_FRACTION=.XX to a value less than 0.9 (the default).
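Both workarounds can be sketched in Python as follows. This is an illustration rather than the actual t5x change: `configure_gpu_memory` and `hide_gpus_from_tensorflow` are hypothetical helper names, and the fraction 0.7 is an arbitrary example value below 0.9. The environment variable has to be set before JAX initializes its GPU backend for it to take effect.

```python
import os


def configure_gpu_memory(mem_fraction: float = 0.7) -> None:
    """Set XLA_PYTHON_CLIENT_MEM_FRACTION (0.7 is only an example value)."""
    if not 0.0 < mem_fraction <= 1.0:
        raise ValueError("mem_fraction must be in the interval (0, 1]")
    # Call this before JAX touches the GPU, e.g. at the top of the entry script.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


def hide_gpus_from_tensorflow() -> bool:
    """Stop TensorFlow from claiming GPU memory; returns False if TF is absent."""
    try:
        import tensorflow as tf
    except ImportError:
        return False
    # Same call as in step 1: TensorFlow sees no GPUs, leaving them to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")
    return True
```

Exporting the variable in the job script before the python command has the same effect as `configure_gpu_memory`.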

Btw, which GPU are you running T5x on?

lintangsutawika commented 2 years ago

I'm using A100s (40GB). I added step 1, but the error persists. Step 2 seems to alleviate the previous issue, but the run now stops at a new error:

gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO cudaDriverVersion 11080
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO Bootstrap : Using eth0:172.31.224.130<0>
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/Plugin: Failed to find ncclNetPlugin_v6 symbol.
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin (v5)
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v6 symbol.
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/Plugin: Loaded coll plugin SHARP (v5)
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO P2P plugin IBext
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/IB : No device found.
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/IB : No device found.
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO NET/Socket : Using [0]eth0:172.31.224.130<0> [1]eth1:172.31.236.189<0> [2]eth2:172.31.232.164<0> [3]eth3:172.31.236.136<0>
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO Using network Socket

gpu-st-p4d-24xlarge-78:41557:41557 [0] external/nccl_archive/src/init.cc:511 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 101c0
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO external/nccl_archive/src/init.cc:1045 -> 5
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO external/nccl_archive/src/init.cc:1091 -> 5
gpu-st-p4d-24xlarge-78:41557:41557 [0] NCCL INFO external/nccl_archive/src/init.cc:1106 -> 5
2022-10-13 05:20:09.232905: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 1 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: invalid usage
Traceback (most recent call last):
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 755, in <module>
    gin_utils.run(main)
  File "/fsx/lintangsutawika/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 711, in main
    _main(argv)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 751, in _main
    train_using_gin()
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/fsx/home-lintangsutawika/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/fsx/lintangsutawika/architecture-objective/t5x/train.py", line 252, in train
    train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1371, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 1385, in get_dataset_inner
    multihost_assert_equal(
  File "/fsx/lintangsutawika/t5x/t5x/utils.py", line 566, in multihost_assert_equal
    multihost_utils.assert_equal(input_tree, fail_message)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/multihost_utils.py", line 175, in assert_equal
    expected = broadcast_one_to_all(in_tree)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/multihost_utils.py", line 75, in broadcast_one_to_all
    in_tree = jax.device_get(_psum(in_tree))
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:266: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: invalid usage

sudhakarsingh27 commented 2 years ago

To confirm, you're still running 2 processes on 2 nodes with ntasks-per-node=1 and without any local_device_ids argument to jax.distributed.initialize?

Btw, this error is new! Could you share the gin config that you're using? It'd help me repro this on my end.

Next quick check - could you run the code with nodes=1 and ntasks-per-node=8 (let's see multiprocess behaviour on a single node)?

Also, out of curiosity, you seem to be using SLURM and then launching two processes on the same node in the comment above? If you used an srun command, it would launch the command on two separate nodes (one process each). Is singularity exec doing something similar?
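To illustrate the srun point: when a command is started through srun, each task gets its own SLURM_PROCID, SLURM_NTASKS, and SLURM_LOCALID, whereas a command run directly in the batch script executes once on the first node. The sketch below reads those variables in Python. `ProcessInfo` and `process_info_from_slurm` are hypothetical names, and mapping the values straight onto `--process_index` and `--process_count` is an assumption about the intended launch.

```python
import os
from typing import Mapping, NamedTuple, Optional


class ProcessInfo(NamedTuple):
    process_index: int  # global task rank, from SLURM_PROCID
    process_count: int  # total number of tasks, from SLURM_NTASKS
    local_index: int    # task rank within its node, from SLURM_LOCALID


def process_info_from_slurm(
    environ: Optional[Mapping[str, str]] = None
) -> ProcessInfo:
    """Read per-task SLURM variables; fail loudly when they are missing."""
    env = os.environ if environ is None else environ
    missing = [name for name in ("SLURM_PROCID", "SLURM_NTASKS") if name not in env]
    if missing:
        raise RuntimeError(f"not running under SLURM? missing: {', '.join(missing)}")
    return ProcessInfo(
        process_index=int(env["SLURM_PROCID"]),
        process_count=int(env["SLURM_NTASKS"]),
        # Defaults to 0 when there is a single task per node.
        local_index=int(env.get("SLURM_LOCALID", "0")),
    )
```

If every task reports process_index 0, the command was most likely started once rather than once per task.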

lintangsutawika commented 2 years ago

The new error seems to be caused by running the different processes on the same set of GPUs.

Running this works (two processes on the same node, but on different sets of GPUs). So the issue is likely, as you said, on my side: how to properly assign each process to the correct node.

export ADDR="127.0.0.1:29500"
CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_DEBUG=INFO \
    singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
        python ${T5X_DIR}/t5x/train.py \
            --gin_search_paths=${PROJECT_DIR} \
            --gin_file="config-base.gin" \
            --gin.MODEL_DIR=\"${MODEL_DIR}\" \
            --gin.USE_CACHED_TASKS=False \
            --alsologtostderr \
            --multiprocess_gpu \
            --coordinator_address="${ADDR}" \
            --process_count 2 \
            --process_index 0 \
    & \
CUDA_VISIBLE_DEVICES=4,5,6,7 NCCL_DEBUG=INFO \
    singularity exec --nv --bind /fsx:/fsx /fsx/lintangsutawika/t5x-env.sif \
        python ${T5X_DIR}/t5x/train.py \
            --gin_search_paths=${PROJECT_DIR} \
            --gin_file="config-base.gin" \
            --gin.MODEL_DIR=\"${MODEL_DIR}\" \
            --gin.USE_CACHED_TASKS=False \
            --alsologtostderr \
            --multiprocess_gpu \
            --coordinator_address="${ADDR}" \
            --process_count 2 \
            --process_index 1

lintangsutawika commented 2 years ago

Closing this as the jax-related issue seems solved for now.

Thanks sudhakarsingh27

sudhakarsingh27 commented 2 years ago

I'd like to point out that although this worked after adding CUDA_VISIBLE_DEVICES, that's no longer the recommended way, given that jax.distributed.initialize provides an argument to specify the GPUs: local_device_ids. (It'd be great if you could confirm that your code works with the recommended way as well.)

lintangsutawika commented 1 year ago

I can confirm. I experimented with splitting 1 node of 8 GPUs into 2 processes with 4 GPUs each using the method you mentioned. The issue was on the SLURM side.

However, I found a related issue: the more nodes I use, the lower the throughput.

I've tested running on 2, 4, and 8 nodes (each node has 8 A100s). I also varied the number of processes (increasing the number of processes reduces the number of GPUs per process). The trend seems to be that, with the batch size fixed, increasing the number of nodes reduces steps per second.

I'm using T5X, but this feels like it might be more of a JAX issue. Could it be a misconfiguration?

[Image: steps_per_second]

AranKomat commented 1 year ago

Hi. I'm working with Lintang on this issue.

The real problem isn't the slowdown when we increase the number of nodes while keeping the total number of GPUs fixed; that is likely due to slow gradient aggregation across GPUs, which is expected. The problem is the significant slowdown when we run with Singularity on two nodes (16 A100s in total) compared with running without Singularity on one node (8 A100s in total).

We get 1.6 steps/sec in the former case and 4 steps/sec in the latter. In both cases the global batch size is 256, and we're using the base-sized T5. Lintang verified that, in the former case, the input batch is split into two batches of size 128, so we don't have issues like an unsplit batch.
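To put those two numbers on a common scale, here is a small Python sketch that converts steps per second into sequences per second and sequences per second per GPU, using only the figures quoted above (global batch size 256, 16 versus 8 GPUs). The function names are hypothetical and the result is plain arithmetic, not a measurement.

```python
def sequences_per_second(steps_per_second: float, global_batch_size: int) -> float:
    """Each step consumes one global batch."""
    return steps_per_second * global_batch_size


def sequences_per_second_per_gpu(
    steps_per_second: float, global_batch_size: int, num_gpus: int
) -> float:
    """Throughput normalized by the number of GPUs in the run."""
    return sequences_per_second(steps_per_second, global_batch_size) / num_gpus


# Two nodes with Singularity: 1.6 steps/sec on 16 GPUs.
two_nodes = sequences_per_second_per_gpu(1.6, 256, 16)
# One node without Singularity: 4 steps/sec on 8 GPUs.
one_node = sequences_per_second_per_gpu(4.0, 256, 8)

print(f"two nodes: {two_nodes:.1f} seqs/sec/GPU")
print(f"one node:  {one_node:.1f} seqs/sec/GPU")
print(f"per-GPU ratio: {one_node / two_nodes:.1f}x")
```

Normalized this way, the two-node run delivers roughly a fifth of the per-GPU throughput of the single-node run, which makes the gap easier to compare across configurations.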

sudhakarsingh27 commented 1 year ago

Hi, a few clarifying questions:

  1. Are you running 1 GPU per process in both cases (i.e. for 2 nodes/16 GPUs and 1 node/8 GPUs)?
  2. Is this metric (steps/s) consistent across runs?
  3. How often is this metric logged (e.g. after X iterations)?
  4. Could you also share seqs_per_second and seqs_per_second_per_core?
  5. Are you also using model-parallelism (basically, num_partitions>1)?
  6. Could you share the device mesh info from the logs?
global_mesh axis_names: ('data', 'model')
global_mesh devices: [[StreamExecutorGpuDevice(id=0, process_index=0)]
 [StreamExecutorGpuDevice(id=1, process_index=1)]
 [StreamExecutorGpuDevice(id=8, process_index=2)]
...

lintangsutawika commented 1 year ago

Sure thing.

  1. I set it up to maximize the number of GPUs per process. For example, 2 nodes with 2 processes gives 8 GPUs per process, while 2 nodes with 4 processes gives 4 GPUs per process. Each GPUs-per-process setting is color coded.
  2. The numbers seem to be consistent if the number of processes matches the number of nodes and all 8 GPUs per node are used. There is higher variation when the number of processes increases, but the trend is still the same.
  3. It's logged at the end of the test run, after 100 steps.
  4. Listed in the image below (numbers updated because I'm now using a virtualenv, which is evidently faster than using containers, probably due to optimized drivers).
  5. No, I'm using num_partitions = 1.
  6. Here is the device mesh info from a few experiments:

-- 2 Nodes 2 Process --

I1018 02:49:33.758280 140684459909120 partitioning.py:331] global_mesh axis_names: ('data', 'model')
I1018 02:49:33.758435 140684459909120 partitioning.py:332] global_mesh devices: [[StreamExecutorGpuDevice(id=0, process_index=0)]
 [StreamExecutorGpuDevice(id=1, process_index=0)]
 [StreamExecutorGpuDevice(id=2, process_index=0)]
 [StreamExecutorGpuDevice(id=3, process_index=0)]
 [StreamExecutorGpuDevice(id=4, process_index=0)]
 [StreamExecutorGpuDevice(id=5, process_index=0)]
 [StreamExecutorGpuDevice(id=6, process_index=0)]
 [StreamExecutorGpuDevice(id=7, process_index=0)]
 [StreamExecutorGpuDevice(id=8, process_index=1)]
 [StreamExecutorGpuDevice(id=9, process_index=1)]
 [StreamExecutorGpuDevice(id=10, process_index=1)]
 [StreamExecutorGpuDevice(id=11, process_index=1)]
 [StreamExecutorGpuDevice(id=12, process_index=1)]
 [StreamExecutorGpuDevice(id=13, process_index=1)]
 [StreamExecutorGpuDevice(id=14, process_index=1)]
 [StreamExecutorGpuDevice(id=15, process_index=1)]]

-- 2 Nodes 4 Process --

I1018 04:01:34.842234 140331303997440 partitioning.py:331] global_mesh axis_names: ('data', 'model')
I1018 04:01:34.842025 140325101854720 partitioning.py:332] global_mesh devices: [[StreamExecutorGpuDevice(id=0, process_index=0)]
 [StreamExecutorGpuDevice(id=1, process_index=0)]
 [StreamExecutorGpuDevice(id=2, process_index=0)]
 [StreamExecutorGpuDevice(id=3, process_index=0)]
 [StreamExecutorGpuDevice(id=4, process_index=1)]
 [StreamExecutorGpuDevice(id=5, process_index=1)]
 [StreamExecutorGpuDevice(id=6, process_index=1)]
 [StreamExecutorGpuDevice(id=7, process_index=1)]
 [StreamExecutorGpuDevice(id=8, process_index=2)]
 [StreamExecutorGpuDevice(id=9, process_index=2)]
 [StreamExecutorGpuDevice(id=10, process_index=2)]
 [StreamExecutorGpuDevice(id=11, process_index=2)]
 [StreamExecutorGpuDevice(id=12, process_index=3)]
 [StreamExecutorGpuDevice(id=13, process_index=3)]
 [StreamExecutorGpuDevice(id=14, process_index=3)]
 [StreamExecutorGpuDevice(id=15, process_index=3)]]

-- 2 Nodes 8 Process --

I1018 04:07:24.959209 140086820651008 partitioning.py:331] global_mesh axis_names: ('data', 'model')
I1018 04:07:24.959397 140086820651008 partitioning.py:332] global_mesh devices: [[StreamExecutorGpuDevice(id=0, process_index=0)]
 [StreamExecutorGpuDevice(id=1, process_index=0)]
 [StreamExecutorGpuDevice(id=2, process_index=1)]
 [StreamExecutorGpuDevice(id=3, process_index=1)]
 [StreamExecutorGpuDevice(id=4, process_index=2)]
 [StreamExecutorGpuDevice(id=5, process_index=2)]
 [StreamExecutorGpuDevice(id=6, process_index=3)]
 [StreamExecutorGpuDevice(id=7, process_index=3)]
 [StreamExecutorGpuDevice(id=8, process_index=4)]
 [StreamExecutorGpuDevice(id=9, process_index=4)]
 [StreamExecutorGpuDevice(id=10, process_index=5)]
 [StreamExecutorGpuDevice(id=11, process_index=5)]
 [StreamExecutorGpuDevice(id=12, process_index=6)]
 [StreamExecutorGpuDevice(id=13, process_index=6)]
 [StreamExecutorGpuDevice(id=14, process_index=7)]
 [StreamExecutorGpuDevice(id=15, process_index=7)]]
image
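The mesh listings above can be summarized mechanically. Below is a minimal Python sketch, assuming the log format shown in this thread (`GpuDevice(id=..., process_index=...)` or `StreamExecutorGpuDevice(...)`), that counts how many devices each process owns. `devices_per_process` is a hypothetical helper, and the sample text is a shortened excerpt of the "2 Nodes 8 Process" log.

```python
import re
from collections import Counter
from typing import Dict

# Matches both device spellings that appear in the logs above.
DEVICE_RE = re.compile(r"(?:StreamExecutor)?GpuDevice\(id=(\d+), process_index=(\d+)\)")


def devices_per_process(log_text: str) -> Dict[int, int]:
    """Count mesh devices per process_index in a global_mesh log excerpt."""
    counts: Counter = Counter()
    for _device_id, process_index in DEVICE_RE.findall(log_text):
        counts[int(process_index)] += 1
    return dict(counts)


sample = """
global_mesh devices: [[StreamExecutorGpuDevice(id=0, process_index=0)]
 [StreamExecutorGpuDevice(id=1, process_index=0)]
 [StreamExecutorGpuDevice(id=2, process_index=1)]
 [StreamExecutorGpuDevice(id=3, process_index=1)]
"""

print(devices_per_process(sample))
```

Run against a full log, an uneven count per process would point to a launch or visibility problem rather than a throughput problem.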