Closed LemonAndRabbit closed 4 months ago
Once Frontier is back, I will try to run there. A quick question: how can I specify the level of the optimizer, such as ZeRO stage 1, 2, 3, etc.?
@jychoi-hpc
You can specify the DeepSpeed configuration in your configuration file's ["NeuralNetwork"]["ds_config"] entry. To enable the ZeRO optimizer, you can pass the following dictionary to that entry:
"zero_optimization": {
"stage": 1,
// other ZeRO optimizer options; passing only the "stage" parameter should work fine
}
Please note that only the OGB example here is ready to run with the DeepSpeed model engine. Other examples need manual modifications similar to this file to enable DeepSpeed training.
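For reference, here is a minimal, self-contained Python sketch of passing such a dictionary to deepspeed.initialize(); the toy model, optimizer, and batch-size entry are illustrative assumptions, not HydraGNN defaults:

import torch
import deepspeed

# Toy model/optimizer just to make the sketch runnable; HydraGNN builds its own
# model and optimizer before handing them to DeepSpeed.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Only "zero_optimization" -> "stage" is needed to enable ZeRO; the batch-size
# entry is a placeholder value (DeepSpeed expects some batch-size setting).
ds_config = {
    "train_micro_batch_size_per_gpu": 32,
    "zero_optimization": {
        "stage": 1,  # 1 = partition optimizer states, 2 = + gradients, 3 = + parameters
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config
)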
I tried to run on Frontier but got the following error. Do you have any quick advice? In the meantime, I will try to run on Perlmutter.
File "/lustre/orion/cph161/world-shared/jyc/frontier/HydraGNN-pr264/./examples/ogb/train_gap.py", line 487, in <module>
model, optimizer, _, _ = deepspeed.initialize(
File "/lustre/orion/world-shared/cph161/jyc/frontier/sw/envs/hydragnn-py39-rocm571-amd/lib/python3.9/site-packages/deepspeed/__init__.py", line 181, in initialize
engine = DeepSpeedEngine(args=args,
File "/lustre/orion/world-shared/cph161/jyc/frontier/sw/envs/hydragnn-py39-rocm571-amd/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 248, in __init__
self._set_distributed_vars(args)
File "/lustre/orion/world-shared/cph161/jyc/frontier/sw/envs/hydragnn-py39-rocm571-amd/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 970, in _set_distributed_vars
get_accelerator().set_device(device_rank)
File "/lustre/orion/world-shared/cph161/jyc/frontier/sw/envs/hydragnn-py39-rocm571-amd/lib/python3.9/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 66, in set_device
torch.cuda.set_device(device_index)
File "/autofs/nccs-svm1_sw/crusher/amdsw/karldev/pytorch-2.2.2-rocm5.7.1/torch/cuda/__init__.py", line 408, in set_device
torch._C._cuda_setDevice(device)
RuntimeError: HIP error: invalid device ordinal
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing HIP_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
FYI, my job script is here: /lustre/orion/world-shared/cph161/jyc/frontier/HydraGNN-pr264/job-ogb.sh
Hi @jychoi-hpc, I once encountered a similar issue on my machine; the cause back then was that LOCAL_RANK (which is used to determine the GPU ID for a process) did not match the GPU IDs visible to that process. I am not sure if this is the same case in your script; I can double-check tomorrow on Frontier.
I am not familiar with Slurm options, but could the --gpu-bind option (or other options) limit GPU visibility for the processes, so that each process only sees one GPU and therefore treats any device ordinal other than 0 as invalid?
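In case it helps, here is a small diagnostic one could drop into the training script right before deepspeed.initialize() to compare each rank's local rank with the GPUs it can actually see; the environment variable names below are the usual torchrun/Slurm/ROCm ones and are assumptions about your launch setup:

import os
import torch

# Print what each process believes its local rank is versus how many devices it
# can see; a local rank >= device_count() would explain "invalid device ordinal".
local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0")))
visible = os.environ.get("CUDA_VISIBLE_DEVICES") or os.environ.get("ROCR_VISIBLE_DEVICES")
print(f"local_rank={local_rank} device_count={torch.cuda.device_count()} visible_devices={visible}")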
@allaffa
Updated in the latest commit: unittest_train_model() in tests/test_graphs.py, run_training() in hydragnn/run_training.py, and run_prediction() in hydragnn/run_prediction.py, to ensure compatibility with the newly added DeepSpeed backend. This update aims to simplify the development of new unit tests for the DeepSpeed backend and facilitate switching existing implementations to it.
@jychoi-hpc You may uncomment these lines to test the ZeRO optimizer on a GPU server (such as Frontier or Perlmutter).
Yes, --gpu-bind=closest can cause a mismatch between LOCAL_RANK and the GPU ID. We usually assign each GPU to a single process. I think that is not how DeepSpeed is meant to be run. I need to learn :) Can you share the srun command you are using on Perlmutter?
Hi Max @allaffa,
I have introduced a new parameter, overwrite_config, for the unittest_train_model() function in tests/test_graphs.py. This addition eliminates the need for creating additional configuration files.
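For illustration, a call in a unit test might look like the sketch below; the first argument and the exact shape of the override dictionary are assumptions on my part, and only the ["NeuralNetwork"]["ds_config"] path comes from this thread:

from tests.test_graphs import unittest_train_model

# Hypothetical usage: pass only the entries to override, so no extra
# JSON configuration file is needed for a DeepSpeed-specific test.
overwrite_config = {
    "NeuralNetwork": {
        "ds_config": {
            "zero_optimization": {"stage": 1},
        }
    }
}

unittest_train_model(
    "GIN",  # assumed model-type argument of the existing test helper
    overwrite_config=overwrite_config,
)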
Hi @jychoi-hpc, I am using this script for testing on Perlmutter:
#!/bin/bash
#SBATCH -A m4716
#SBATCH -J HydraGNN
#SBATCH -C gpu
#SBATCH -q regular
#SBATCH -t 1:00:00
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH -N 1
## remove write permission for others in terms of newly created files and dirs
umask 002
## Module
module reset
module load pytorch/2.0.1
HYDRAGNN_DIR=/global/cfs/cdirs/m4133/zye327/HydraGNN
module use -a /global/cfs/cdirs/m4133/jyc/perlmutter/sw/modulefiles
module load hydragnn/pytorch2.0.1-v2
echo "python:" `which python`
export PYTHONPATH=$HYDRAGNN_DIR:$PYTHONPATH
## Envs
export MPICH_ENV_DISPLAY=0
export MPICH_VERSION_DISPLAY=0
export MPICH_GPU_SUPPORT_ENABLED=0
export HYDRAGNN_NUM_WORKERS=0
export HYDRAGNN_USE_VARIABLE_GRAPH_SIZE=1
export HYDRAGNN_AGGR_BACKEND=mpi
export HYDRAGNN_VALTEST=1
set -x
srun -N1 -n4 -c32 --gpus-per-task=1 \
python train_gap.py --adios gap \
2>&1 | tee run-ogb-gap-adios.log
set +x
I tested on Perlmutter but got the same error:
File "/global/cfs/cdirs/m4133/jyc/perlmutter/HydraGNN-ds/examples/ogb/train_gap.py", line 485, in <module>
INFO (rank 0): Added key: store_based_barrier_key:2 to store for rank: 0
model, optimizer, _, _ = deepspeed.initialize(
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/__init__.py", line 181, in initialize
model, optimizer, _, _ = deepspeed.initialize(
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/__init__.py", line 181, in initialize
engine = DeepSpeedEngine(args=args,
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 248, in __init__
engine = DeepSpeedEngine(args=args,
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 248, in __init__
self._set_distributed_vars(args)
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 970, in _set_distributed_vars
self._set_distributed_vars(args)
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 970, in _set_distributed_vars
get_accelerator().set_device(device_rank)
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 66, in set_device
get_accelerator().set_device(device_rank)
File "/global/cfs/cdirs/m4133/jyc/perlmutter/sw/hydragnn-pytorch2.0.1-v2/lib/python3.9/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 66, in set_device
torch.cuda.set_device(device_index)
File "/global/common/software/nersc/pm-2022q4/sw/pytorch/2.0.1/lib/python3.9/site-packages/torch/cuda/__init__.py", line 350, in set_device
torch.cuda.set_device(device_index)
File "/global/common/software/nersc/pm-2022q4/sw/pytorch/2.0.1/lib/python3.9/site-packages/torch/cuda/__init__.py", line 350, in set_device
torch._C._cuda_setDevice(device)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
I installed DeepSpeed using pip (pip install deepspeed), and the version is deepspeed 0.14.4.
Which version of DeepSpeed are you using on Perlmutter? Could the error be related to the version?
I am also using deepspeed 0.14.4, so that should not be the reason. I am double-checking on Perlmutter to see if there is a problem in recent commits. Last time I checked this commit, and the script worked fine on Perlmutter.
I think this is related to srun. DeepSpeed wants to see its neighboring GPUs, instead of being assigned one GPU per process (which is HydraGNN's default approach).
I found the following srun works on Perlmutter:
srun -N1 -n4 -c32 --gres=gpu:4 \
python -u examples/ogb/train_gap.py gap --adios --use_deepspeed
For comparison, this is the srun command without DeepSpeed:
srun -N1 -n4 -c32 --gpus-per-task=1 \
python -u examples/ogb/train_gap.py gap --adios
I don't know if this is the correct way of running DeepSpeed.
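For what it's worth, that would be consistent with how DeepSpeed selects its device: the traceback shows get_accelerator().set_device(device_rank), which indexes the GPUs visible to the process by the local rank. A rough Python sketch of the difference between the two launch modes (the environment handling is simplified and assumed):

import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0")))

# With `--gres=gpu:4` every rank on the node sees all four GPUs, so indexing by
# local rank (roughly what DeepSpeed does internally) is valid:
#   torch.cuda.device_count() == 4, local_rank in {0, 1, 2, 3}  -> OK
# With `--gpus-per-task=1` every rank sees a single GPU as device 0, so any
# local rank other than 0 triggers "invalid device ordinal":
#   torch.cuda.device_count() == 1, local_rank == 3             -> RuntimeError
if local_rank < torch.cuda.device_count():
    torch.cuda.set_device(local_rank)
else:
    raise RuntimeError(
        f"local rank {local_rank} exceeds visible device count "
        f"{torch.cuda.device_count()}; launch with all node GPUs visible"
    )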
Since the code is working on Perlmutter, I have no issue with merging to main. We may need to figure out the best way of running with DeepSpeed later.
Hi @jychoi-hpc, thank you for testing the code!
I will look into the best way to launch the DeepSpeed backend.
FYI, I just tested running on Frontier using the srun --gres=gpu:8 option.
This pull request introduces support for training with the DeepSpeed model engine and lays the groundwork for Chaojian's pipeline parallelism training implementation.
To enable DeepSpeed in the HydraGNN framework, we have made the following modifications to the codebase:
- The setup_ddp() function is updated to be compatible with DeepSpeed distributed backend initialization.
- A parse_deepspeed_config() function has been added to parse DeepSpeed configurations from the HydraGNN configuration dictionary under the ["NeuralNetwork"]["ds_config"] entry.
- The train_validate_test() and train() functions have been modified to be compatible with the DeepSpeed model engine training API. Note that these methods remain compatible with their previous APIs used elsewhere in the codebase.
To validate the correctness of our implementation, we used the train_gap.py script as an example. Users can switch between the torch.distributed and DeepSpeed model engine backends by passing or omitting the --use_deepspeed flag. Training with the DeepSpeed backend yields similar training throughput and a 33% reduction in peak memory consumption on Perlmutter (without enabling the DeepSpeed ZeRO optimizer or other DeepSpeed optimizations).
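As a rough illustration of such a switch (not the exact HydraGNN code; the model, optimizer, and config values are placeholders), the flag could toggle between wrapping the model in DistributedDataParallel and handing it to the DeepSpeed engine:

import argparse
import os
import torch
import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--use_deepspeed", action="store_true")
args, _ = parser.parse_known_args()

# Placeholder model and optimizer standing in for the HydraGNN ones.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

if args.use_deepspeed:
    # DeepSpeed path: the returned engine owns backward(), step(), gradient
    # accumulation, and (optionally) ZeRO partitioning.
    model, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config={"train_micro_batch_size_per_gpu": 32},
    )
else:
    # torch.distributed path: wrap the model in DistributedDataParallel.
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model.to(local_rank), device_ids=[local_rank]
    )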