Closed: Lauler closed this issue 1 year ago.
Seems possibly related to this Lightning issue: https://github.com/Lightning-AI/lightning/issues/10098
Have you tried the following config:
...
#SBATCH --nodes=<n>
#SBATCH --tasks-per-node=<m>
#SBATCH --gpus-per-node=<m>
...
trainer.devices=-1 \
trainer.num_nodes=$SLURM_JOB_NUM_NODES \
This will start SLURM_TASKS_PER_NODE tasks on each node. Each task will have access to SLURM_GPUS_PER_NODE GPUs, but each Slurm task will use only one GPU.
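To confirm what each task actually sees under this layout, a quick sanity check (just a sketch, not part of any particular job script) is to run something like this inside the allocation before the training command:
# Sketch: print each task's Slurm IDs and the GPUs it can see.
# srun -l prefixes every output line with the task ID, so the mapping is easy to read.
srun -l bash -c 'echo "PROCID=$SLURM_PROCID LOCALID=$SLURM_LOCALID NODEID=$SLURM_NODEID CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"'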
Other potential causes may depend on the specifics of your setup; to investigate those, you may wish to enable the following:
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
This helped me track down an issue I had with --gpus-per-task; see https://github.com/NVIDIA/pyxis/issues/73.
Thanks for sharing and suggesting fixes. When I try your suggested config of
...
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=4
...
trainer.devices=-1 \
trainer.num_nodes=$SLURM_JOB_NUM_NODES \
it still has problems setting the global ranks correctly. All the processes now report GLOBAL_RANK: 0. Additionally, it doesn't seem to have a proper sense of the world size:
0: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
1: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
2: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
...
...
7: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
In all the different configurations of sbatch parameters I've tried, PyTorch Lightning seems to have issues setting GLOBAL_RANK correctly. I think it gets set and calculated here: https://github.com/Lightning-AI/lightning/blob/cc56539cd3b3875d9d374d55004b4b86e07b47a9/src/pytorch_lightning/strategies/ddp.py#L194, but I honestly have trouble understanding where the values come from with all those levels of inheritance and imports.
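Under Slurm, Lightning's cluster environment is supposed to pick these values up directly from Slurm's per-task variables, so a rough sketch of what each task ought to report (not taken from the thread, just the standard Slurm variables) is:
# Sketch: the values a SLURM-based cluster environment would be expected to derive per task.
# With --nodes=2 and --ntasks-per-node=4, SLURM_PROCID should run 0..7 and SLURM_NTASKS should be 8.
echo "global rank (SLURM_PROCID):  $SLURM_PROCID"
echo "local rank  (SLURM_LOCALID): $SLURM_LOCALID"
echo "node rank   (SLURM_NODEID):  $SLURM_NODEID"
echo "world size  (SLURM_NTASKS):  $SLURM_NTASKS"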
The misconfigured ranks result in all processes except one crashing when they try to connect to the same address:
0: Error executing job with overrides: ['trainer.devices=-1', 'trainer.num_nodes=2', 'trainer.max_epochs=null', 'trainer.max_steps=300000', 'trainer.val_check_interval=300', 'trainer.log_every_n_steps=50', 'trainer.limit_val_batches=50', 'trainer.limit_test_batches=50', 'trainer.accumulate_grad_batches=1', 'trainer.precision=16', 'model.micro_batch_size=6', 'model.global_batch_size=192', 'model.tensor_model_parallel_size=1', 'model.pipeline_model_parallel_size=1', 'model.max_position_embeddings=1024', 'model.encoder_seq_length=1024', 'model.hidden_size=768', 'model.ffn_hidden_size=3072', 'model.num_layers=12', 'model.num_attention_heads=12', 'model.init_method_std=0.021', 'model.hidden_dropout=0.1', 'model.layernorm_epsilon=1e-5', 'model.tokenizer.vocab_file=gpt2-vocab.json', 'model.tokenizer.merge_file=gpt2-merges.txt', 'model.data.data_prefix=[1.0,hfbpe_gpt_training_data_text_document]', 'model.data.num_workers=64', 'model.data.seq_length=1024', "model.data.splits_string='980,10,10'", 'model.optim.name=fused_a
0: dam', 'model.optim.lr=6e
0: -4', 'model.optim.betas=[0.9,0.95]', 'model.optim.weight_decay=0.1', 'model.optim.sched.name=CosineAnnealing', 'model.optim.sched.warmup_steps=750', 'model.optim.sched.constant_steps=80000', 'model.optim.sched.min_lr=6e-5', 'exp_manager.resume_if_exists=True', 'exp_manager.resume_ignore_no_checkpoint=True', 'exp_manager.create_checkpoint_callback=True', 'exp_manager.checkpoint_callback_params.monitor=val_loss', 'exp_manager.checkpoint_callback_params.save_top_k=3', 'exp_manager.checkpoint_callback_params.mode=min', 'exp_manager.checkpoint_callback_params.always_save_nemo=False']
0: Traceback (most recent call last):
0: File "/workspace/nemo/examples/nlp/language_modeling/megatron_gpt_pretraining.py", line 88, in main
0: trainer.fit(model)
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
0: self._call_and_handle_interrupt(
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
0: return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
0: return function(*args, **kwargs)
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
0: results = self._run(model, ckpt_path=self.ckpt_path)
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1102, in _run
0: self.strategy.setup_environment()
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 157, in setup_environment
0: self.setup_distributed()
0: File "/opt/conda/lib/python3.8/site-packages/nemo/collections/nlp/parts/nlp_overrides.py", line 81, in setup_distributed
0: super().setup_distributed()
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 210, in setup_distributed
0: init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
0: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 374, in init_dist_connection
0: torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
0: File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 627, in init_process_group
0: store, rank, world_size = next(rendezvous_iterator)
0: File "/opt/conda/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 246, in _env_rendezvous_handler
0: store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
0: File "/opt/conda/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 177, in _create_c10d_store
0: return TCPStore(
0: RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:53394 (errno: 98 - Address already in use). The server socket has failed to bind to 0.0.0.0:53394 (errno: 98 - Address already in use).
0:
0: Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
If I change trainer.devices=-1 to trainer.devices=4, I instead get the following error:
4: Error executing job with overrides: ['trainer.devices=4', 'trainer.num_nodes=2', 'trainer.max_epochs=null', 'trainer.max_steps=300000', 'trainer.val_check_interval=300', 'trainer.log_every_n_steps=50', 'trainer.limit_val_batches=50', 'trainer.limit_test_batches=50', 'trainer.accumulate_grad_batches=1', 'trainer.precision=16', 'model.micro_batch_size=6', 'model.global_batch_size=192', 'model.tensor_model_parallel_size=1', 'model.pipeline_model_parallel_size=1', 'model.max_position_embeddings=1024', 'model.encoder_seq_length=1024', 'model.hidden_size=768', 'model.ffn_hidden_size=3072', 'model.num_layers=12', 'model.num_attention_heads=12', 'model.init_method_std=0.021', 'model.hidden_dropout=0.1', 'model.layernorm_epsilon=1e-5', 'model.tokenizer.vocab_file=gpt2-vocab.json', 'model.tokenizer.merge_file=gpt2-merges.txt', 'model.data.data_prefix=[1.0,hfbpe_gpt_training_data_text_document]', 'model.data.num_workers=64', 'model.data.seq_length=1024', "model.data.splits_string='980,10,10'", 'model.optim.name=fused_ad
4: am', 'model.optim.lr=6e-
4: 4', 'model.optim.betas=[0.9,0.95]', 'model.optim.weight_decay=0.1', 'model.optim.sched.name=CosineAnnealing', 'model.optim.sched.warmup_steps=750', 'model.optim.sched.constant_steps=80000', 'model.optim.sched.min_lr=6e-5', 'exp_manager.resume_if_exists=True', 'exp_manager.resume_ignore_no_checkpoint=True', 'exp_manager.create_checkpoint_callback=True', 'exp_manager.checkpoint_callback_params.monitor=val_loss', 'exp_manager.checkpoint_callback_params.save_top_k=3', 'exp_manager.checkpoint_callback_params.mode=min', 'exp_manager.checkpoint_callback_params.always_save_nemo=False']
4: Traceback (most recent call last):
4: File "/workspace/nemo/examples/nlp/language_modeling/megatron_gpt_pretraining.py", line 64, in main
4: trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/argparse.py", line 345, in insert_env_defaults
4: return fn(self, **kwargs)
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 433, in __init__
4: self._accelerator_connector = AcceleratorConnector(
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 214, in __init__
4: self._set_parallel_devices_and_init_accelerator()
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 545, in _set_parallel_devices_and_init_accelerator
4: self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/cuda.py", line 77, in parse_devices
4: return device_parser.parse_gpu_ids(devices, include_cuda=True)
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/device_parser.py", line 125, in parse_gpu_ids
4: return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
4: File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/device_parser.py", line 209, in _sanitize_gpu_ids
4: raise MisconfigurationException(
4: pytorch_lightning.utilities.exceptions.MisconfigurationException: You requested gpu: [0, 1, 2, 3]
4: But your machine only has: [0]
We haven't previously used this pattern of SLURM_TASKS_PER_NODE being equal to the number of GPUs per node when training with Megatron-LM and launching jobs with torch.distributed.launch. There we launch with one process per node.
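For reference, that Megatron-LM launch pattern looks roughly like the sketch below (script name and arguments are placeholders); each node runs a single launcher process, which then spawns one worker per GPU:
# Rough sketch of the one-launcher-per-node pattern (arguments abbreviated, script name is a placeholder).
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --node_rank=$SLURM_NODEID \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    pretrain_gpt.py ...bunch-of-args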
The likely culprit in our case seems to be that only 1 GPU appears to be available per process when running torch.cuda.device_count(). However, all 4 GPU devices show up in each of the individual processes when running nvidia-smi or nvidia-smi -L.
*edit: Although echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" shows only a single visible device for each task :unamused:.
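That is consistent with torch.cuda.device_count() respecting CUDA_VISIBLE_DEVICES, while nvidia-smi queries the driver directly and ignores it. A per-task check along these lines (a sketch, not part of our scripts) makes the mismatch visible:
# Sketch: compare the GPUs physically present on the node with the GPUs each task may actually use.
srun -l bash -c '
  echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
  nvidia-smi -L | wc -l                                        # GPUs on the node (ignores CUDA_VISIBLE_DEVICES)
  python -c "import torch; print(torch.cuda.device_count())"   # GPUs this task can use
'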
We found a rather hacky solution to make training work. To anyone reading this in the future who runs into the same issue:
Our problem was that not all GPUs in a node were visible to a process whenever we started Slurm jobs with the recommended setting of --ntasks-per-node equal to trainer.devices. This resulted in only 1 GPU being visible per process, and we weren't able to rectify that.
Our workaround: request only --nodes=2 and --gres=gpu:4 (without --ntasks-per-node). Export the MASTER_ADDR as an environment variable. Export the GLOBAL_RANK of the current process in the bash script of the file that launches the training. Do NOT export the LOCAL_RANK as an environment variable (Lightning/PyTorch DDP will get stuck without initializing the rest of the GPUs if you do). For reference, here's our sbatch script:
#!/bin/bash -l
#SBATCH --partition=gpu
#SBATCH --qos=test
#SBATCH --account=p200097
#SBATCH --job-name=gpt_nemo
#SBATCH --nodes=4
#SBATCH --gres=gpu:4
#SBATCH --time=0-00:30:00
#SBATCH --output=logs/gpt_nemo.log
# Modules
pwd
module purge
module load Singularity-CE
## Create needed distributed env variables
addr=$(/bin/hostname -s)
export MASTER_ADDR=$addr
export MASTER_PORT=16783 # Meluxina overwrites this variable after srun
export GPUS_PER_NODE=4
export NCCL_CROSS_NIC=1
# debugging flags (optional)
export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
export PYTHONFAULTHANDLER=1
export HYDRA_FULL_ERROR=1
# Logfile
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
PROJECT=/project/home/p200097/faton/nemo_test # Use abs path, not symbolic link
CONTAINER_PATH=/project/home/p200097/faton/nemo_test/nemo2209b.sif
LOGGING=$PROJECT/logs
LOGFILE="${LOGGING}/%x_${DATETIME}.log"
echo $LOGFILE
ls -lh
cmd="srun -l --output=$LOGGING/gpt_nemo_$DATETIME.log \
singularity exec --nv --bind $PROJECT:$PROJECT --bind /project/scratch/p200097/data/nemo_test:/mnt $CONTAINER_PATH \
bash $PROJECT/training_args.sh"
$cmd
And here's our training_args.sh that launches the training for each process (1 process per node):
/bin/hostname -s
export MASTER_PORT=16783
export NODE_RANK=$SLURM_NODEID
# export LOCAL_RANK=$SLURM_LOCALID # Local rank needs to be uninitialized for Lightning to work properly with DDP and 1 process per node
# export GLOBAL_RANK=$SLURM_PROCID # if --ntasks-per-node == devices, then PROCID is the global_rank. But training with --ntasks-per-node doesn't work.
export GLOBAL_RANK=$((SLURM_NODEID * GPUS_PER_NODE + LOCAL_RANK)) # When only 1 process per node, this calculates global_rank
echo "----------"
echo "NODE_RANK" $NODE_RANK
echo "LOCAL_RANK" $LOCAL_RANK
echo "GLOBAL_RANK" $GLOBAL_RANK
echo "WORLD_SIZE" $WORLD_SIZE
echo "MASTER_PORT" $MASTER_PORT
echo "---------------------"
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
nvidia-smi -L
python /workspace/nemo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \
--config-path=/workspace/nemo/examples/nlp/language_modeling/conf \
--config-name=megatron_gpt_config \
trainer.devices=$GPUS_PER_NODE \
...bunch-of-args \
...
Hope this helps someone in the future trying to train multi-node with NeMo and Slurm.
We use Slurm for all our clusters, and none of the above is needed. We follow the PTL guidelines, and the only thing we normally do is set the CUDA_VISIBLE_DEVICES variable with all the GPUs in the list. That seems to work fine without resorting to these steps.
So if there are 8 GPUs per node, we do CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python nemo_script.py ... trainer.num_nodes=x trainer.devices=-1
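Putting that together with a Slurm allocation of 4 GPUs per node, a minimal sketch of such a launch (the script name and remaining overrides are placeholders) would be:
# Minimal sketch assuming 4 GPUs per node; CUDA device indices are 0-based.
export CUDA_VISIBLE_DEVICES=0,1,2,3
python nemo_script.py \
    trainer.num_nodes=$SLURM_JOB_NUM_NODES \
    trainer.devices=-1 \
    ...bunch-of-args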
Thank you very much @titu1994. It had never occurred to me that CUDA_VISIBLE_DEVICES would be something a user would manually want or need to edit; I always assumed it would be set correctly by the system and not be something for a user to touch.
It would probably be helpful if you posted an example sbatch script in the documentation, to save others from future headaches.
Thanks again for the tip about setting CUDA_VISIBLE_DEVICES; it works perfectly with the PTL guidelines.
@SeanNaren we should note this in your AWS tutorial (though I dunno if that uses Slurm directly or AWS SageMaker). Maybe also let's comment on the PTL Slack to add this info to the end of https://pytorch-lightning.readthedocs.io/en/stable/clouds/cluster_advanced.html
Describe the bug
We are trying to get multi-node training to work with NeMo Megatron by following the quick start steps in your GPT model training docs. We're using Slurm on an HPC, and are able to successfully train using Megatron-LM, but not with NeMo.
NeMo keeps insisting we are running multi-node training without SLURM handling the processes:
and the global ranks of our GPUs seem to be incorrectly initialised as a result:
Steps/Code to reproduce bug
%environment export LC_ALL=C
And here are the settings and launch script in training_args.sh:
Expected behavior
NeMo/PyTorch Lightning recognizing that the job is run through Slurm and starting it successfully.
Environment overview (please complete the following information)
Additional context
One node in our case consists of 4 A100 GPUs.
We saw that you referred to the PyTorch Lightning documentation when asked about multi-node training in this previous issue. However, the PyTorch Lightning docs' example sbatch script has a setting that makes no sense to us:
If we set --ntasks-per-node=4, this creates 4 separate processes on a node consisting of 4 GPUs; each GPU is placed in a separate process, with only a single GPU being available per process. We tried the above method, and it only resulted in training crashing because NeMo/Lightning expected 4 devices (0, 1, 2, 3) but only saw one device (0) per process.
In the GitHub issue thread we referenced, you write that you use Slurm internally. Could you provide a working example of launching a multi-node job with NeMo Megatron using sbatch and the example in your docs?
Log outputs: