microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] CUDA error in pipeline parallel #5536

Open sunkun1997 opened 1 month ago

sunkun1997 commented 1 month ago

Describe the bug
When I train the model on two nodes for a pipeline-parallel task, each node has eight GPUs, so the LOCAL_RANK values passed in on node 1 are 8-15. Line 201 of deepspeed/runtime/pipe/module.py is

self.to(get_accelerator().device_name(self.local_rank))

Here local_rank does not match the on-node GPU indices 0-7, so a CUDA error is raised.
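
A minimal sketch of the failure mode on a hypothetical 8-GPU node (the device-string construction paraphrases the quoted line; 12 is just one of the out-of-range ranks):

import torch

local_rank = 12                    # global rank leaked in as LOCAL_RANK
device = f"cuda:{local_rank}"      # valid ordinals on this node are only 0-7
torch.empty(1).to(device)          # RuntimeError: CUDA error: invalid device ordinal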

loadams commented 1 month ago

Hi @sunkun1997 - can you please share more information on your setup, ds_config, ds_report, and sample repro script?

sunkun1997 commented 1 month ago

Repro script: I use the https://github.com/microsoft/DeepSpeedExamples/tree/master/training/pipeline_parallelism example, but I made a slight modification to fit our environment. In our environment, each node has four environment variables: the number of nodes (WORLD_SIZE), the node rank (RANK), the master node IP (MASTER_ADDR), and the port (MASTER_PORT). So I modified run.sh:

gpu=8                                   # GPUs per node
n=$(($WORLD_SIZE * $gpu))               # total number of processes
start_rank=$(($RANK * $gpu))            # first rank assigned to this node
end_rank=$((($RANK + 1) * $gpu))
for ((i=$start_rank; i<$end_rank; i++))
do
  {
    LOCAL_RANK=$i WORLD_SIZE=$n MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT python train.py \
    --p $gpu \
    --steps=200
  } &
done
wait
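
For reference, torch.distributed's env:// convention expects RANK to be the per-process global rank and LOCAL_RANK to be the on-node GPU index; a hedged launcher sketch under that convention (reusing the variable names from the run.sh above) would be:

gpu=8
n=$(($WORLD_SIZE * $gpu))            # WORLD_SIZE is still the node count here
start_rank=$(($RANK * $gpu))         # RANK is still the node rank here
for ((i=0; i<$gpu; i++))
do
  RANK=$(($start_rank + $i)) LOCAL_RANK=$i WORLD_SIZE=$n \
  MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT \
  python train.py --p $gpu --steps=200 &
done
wait

The split matters because LOCAL_RANK is what code like PipelineModule uses to pick a CUDA device, while RANK identifies the process across the whole job.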

And I modified the main of train.py:

if __name__ == '__main__':
    import json
    args = get_args()
    with open('./ds_config.json', 'r') as f:
        args.deepspeed_config = json.load(f)
    args.local_rank = int(os.environ['LOCAL_RANK'])    # here this is actually the *global* rank
    args.world_size = int(os.environ['WORLD_SIZE'])
    deepspeed.init_distributed(dist_backend=args.backend, rank=args.local_rank,
                               world_size=args.world_size, auto_mpi_discovery=False)
    torch.cuda.set_device(args.local_rank % 8)         # map back onto the 8 local GPUs
    if args.pipeline_parallel_size == 0:
        train_base(args)
    else:
        train_pipe(args)
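
A hedged sketch of the same main under the per-process convention above, letting init_distributed pick the rank and world size up from the RANK/WORLD_SIZE environment variables instead of passing the global rank in as local_rank (get_args, train_base, and train_pipe as in the example train.py):

import json
import os
import torch
import deepspeed

if __name__ == '__main__':
    args = get_args()
    with open('./ds_config.json', 'r') as f:
        args.deepspeed_config = json.load(f)
    args.local_rank = int(os.environ['LOCAL_RANK'])    # 0..7 on every node
    torch.cuda.set_device(args.local_rank)             # bind this process to its GPU first
    deepspeed.init_distributed(dist_backend=args.backend)  # rank/world size read from env
    if args.pipeline_parallel_size == 0:
        train_base(args)
    else:
        train_pipe(args)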

Then each node runs run.sh, and training raises a CUDA error:

Traceback (most recent call last):
  File "train.py", line 165, in <module>
    train_pipe(args)
  File "train.py", line 131, in train_pipe
    net = PipelineModule(layers=join_layers(net),
  File "/home/ray/anaconda3/lib/python3.8/site-packages/deepspeed/runtime/pipe/module.py", line 201, in __init__
    self.to(get_accelerator().device_name(self.local_rank))
  File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1145, in to
    return self._apply(convert)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 797, in _apply
    module._apply(fn)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 820, in _apply
    param_applied = fn(param)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1143, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

sunkun1997 commented 1 month ago

By the way, if I launch train.py with LOCAL_RANK=$((i % $gpu)) RANK=$i WORLD_SIZE=$n MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT python train.py and initialize the distributed environment with deepspeed.init_distributed(dist_backend="nccl", rank=args.rank, world_size=args.world_size, auto_mpi_discovery=False), then the nodes can't communicate with each other and it raises:

Traceback (most recent call last):
  File "train.py", line 166, in <module>
    train_pipe(args)
  File "train.py", line 137, in train_pipe
    trainset = cifar_trainset(args.local_rank)
  File "train.py", line 30, in cifar_trainset
    dist.barrier()
  File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3328, in barrier
    work = default_pg.barrier(opts=opts)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.
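
An ncclInternalError at the first barrier is often a transport/interface problem between the nodes rather than a rank-numbering one. A hedged first debugging step, using NCCL's standard NCCL_DEBUG and NCCL_SOCKET_IFNAME environment variables (eth0 below is a placeholder for whichever NIC actually connects the two nodes):

# Run on every node before launching run.sh; the INIT/NET lines in the
# log show which interfaces and routes the ranks tried to use.
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0   # placeholder: pick the inter-node NIC
bash run.sh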