hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: Deadlock when using gemini and pipeline at the same time #3383

Open liuzeming-yuxi opened 1 year ago

liuzeming-yuxi commented 1 year ago

🐛 Describe the bug

Hi~ We tried to use pipeline parallelism + Gemini to train a model, but it seems there is a deadlock during communication. The following is a simple reproduction based on the official example:

import os
import torch
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import is_using_pp, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup
import torch.distributed as dist

class DummyDataloader():
    def __init__(self, length, batch_size):
        self.length = length
        self.batch_size = batch_size
    def generate(self):
        data = torch.rand(self.batch_size, 3, 224, 224)
        label = torch.randint(low=0, high=10, size=(self.batch_size,))
        return data, label
    def __iter__(self):
        self.step = 0
        return self
    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration
    def __len__(self):
        return self.length

def main():
    # launch from torch
    parser = colossalai.get_default_parser()
    args = parser.parse_args()
    colossalai.launch_from_torch(config=args.config)

    # get logger
    logger = get_dist_logger()
    logger.info("initialized distributed environment", ranks=[0])
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    # create model
    model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
                        patch_size=gpc.config.PATCH_SIZE,
                        hidden_size=gpc.config.HIDDEN_SIZE,
                        depth=gpc.config.DEPTH,
                        num_heads=gpc.config.NUM_HEADS,
                        mlp_ratio=gpc.config.MLP_RATIO,
                        num_classes=10,
                        init_method='jax',
                        checkpoint=gpc.config.CHECKPOINT)
    pipelinable = PipelinableContext()
    pg = ProcessGroup(ranks = [dist.get_rank()])
    with ColoInitContext(device=torch.device('cuda'), dtype=torch.half, default_pg=pg):
        with pipelinable:
            model = _create_vit_model(**model_kwargs)
    pipelinable.to_layer_list()
    pipelinable.policy = "uniform"
    model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
    model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)

    # use synthetic dataset
    # we train for 10 steps and eval for 5 steps per epoch
    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

    # create loss function
    criterion = CrossEntropyLoss(label_smoothing=0.1)

    # create optimizer
    optimizer = GeminiAdamOptimizer(model, lr=gpc.config.LEARNING_RATE)

    # initialize
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                         optimizer=optimizer,
                                                                         criterion=criterion,
                                                                         train_dataloader=train_dataloader,
                                                                         test_dataloader=test_dataloader)

    data_iter = iter(train_dataloader)
    logger.info("Engine is built", ranks=[0])
    for epoch in range(gpc.config.NUM_EPOCHS):
        # training
        engine.zero_grad()
        _, _, loss = engine.execute_schedule(data_iter, return_output_label=False)
        engine.step()
        # the loss is only returned on the last pipeline stage
        if dist.get_rank() == dist.get_world_size() - 1:
            logger.info(f"epoch: {epoch}, loss: {loss.item()}")

if __name__ == '__main__':
    main()
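
For reference, the config file passed via --config is not shown above. The following is an approximate reconstruction based on the "Your Config" block echoed later in the log (field names follow the official ViT example), so treat it as a sketch rather than the exact file we used:

# config.py (approximate reconstruction from the logged config values)
BATCH_SIZE = 4
NUM_EPOCHS = 2
NUM_CLASSES = 10
WARMUP_EPOCHS = 1
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 16
DEPTH = 5
NUM_HEADS = 4
MLP_RATIO = 2
SEQ_LENGTH = 197
CHECKPOINT = False
NUM_CHUNKS = 2
NUM_MICRO_BATCHES = 4
LEARNING_RATE = 0.003
WEIGHT_DECAY = 0.3
clip_grad_norm = 1.0
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(pipeline=4, tensor=dict(mode='1d', size=1))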

To locate the bug, we printed a c10dDEBUG line at every PyTorch c10d call to display the call information. A rough sketch of how such tracing can be added is shown below; after it comes the corresponding log for the code above.
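
The sketch wraps a few torch.distributed primitives with a debug print. It is only an illustration of the idea, not the exact instrumentation we used (which also prints device, src/dst, group and tag):

# Tracing sketch: wrap selected torch.distributed (c10d) calls with a debug print.
# Apply the patch before importing colossalai, so code that binds these names at
# import time (e.g. "from torch.distributed import isend") picks up the wrapper.
import functools
import torch.distributed as dist

def _trace(name):
    orig = getattr(dist, name)

    @functools.wraps(orig)
    def wrapper(first_arg, *args, **kwargs):
        # send/recv/isend/irecv take a tensor first; collectives take a list, so shape may be None
        shape = tuple(first_arg.shape) if hasattr(first_arg, "shape") else None
        print(f"c10dDEBUG rank:{dist.get_rank()} {name}, shape:{shape}", flush=True)
        return orig(first_arg, *args, **kwargs)

    setattr(dist, name, wrapper)

for _op in ("send", "recv", "isend", "irecv", "all_gather", "all_gather_object"):
    _trace(_op)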

[03/31/23 13:04:17] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:521 set_device                      
[03/31/23 13:04:17] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:521 set_device                      
                    INFO     colossalai - colossalai - INFO: process rank 3 is  
                             bound to device 3                                  
c10dDEBUG rank:3 all_gather_object
[03/31/23 13:04:17] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:521 set_device                      
                    INFO     colossalai - colossalai - INFO: process rank 0 is  
                             bound to device 0                                  
c10dDEBUG rank:0 all_gather_object
                    INFO     colossalai - colossalai - INFO: process rank 2 is  
                             bound to device 2                                  
c10dDEBUG rank:2 all_gather_object
[03/31/23 13:04:17] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:521 set_device                      
                    INFO     colossalai - colossalai - INFO: process rank 1 is  
                             bound to device 1                                  
c10dDEBUG rank:1 all_gather_object
c10dDEBUG rank:0 all_gather
c10dDEBUG rank:1 all_gather
c10dDEBUG rank:3 all_gather
c10dDEBUG rank:2 all_gather
c10dDEBUG rank:1 all_gather
c10dDEBUG rank:0 all_gather
c10dDEBUG rank:3 all_gather
c10dDEBUG rank:2 all_gather
[03/31/23 13:04:19] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:557 set_seed                        
[03/31/23 13:04:19] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:557 set_seed                        
[03/31/23 13:04:19] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:557 set_seed                        
[03/31/23 13:04:19] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/context/paral
                             lel_context.py:557 set_seed                        
                    INFO     colossalai - colossalai - INFO: initialized seed on
                             rank 3, numpy: 1024, python random: 1024,          
                             ParallelMode.DATA: 1024, ParallelMode.TENSOR:      
                             4096,the default parallel seed is                  
                             ParallelMode.DATA.                                 
                    INFO     colossalai - colossalai - INFO: initialized seed on
                             rank 0, numpy: 1024, python random: 1024,          
                             ParallelMode.DATA: 1024, ParallelMode.TENSOR:      
                             1024,the default parallel seed is                  
                             ParallelMode.DATA.                                 
                    INFO     colossalai - colossalai - INFO: initialized seed on
                             rank 2, numpy: 1024, python random: 1024,          
                             ParallelMode.DATA: 1024, ParallelMode.TENSOR:      
                             3072,the default parallel seed is                  
                             ParallelMode.DATA.                                 
                    INFO     colossalai - colossalai - INFO: initialized seed on
                             rank 1, numpy: 1024, python random: 1024,          
                             ParallelMode.DATA: 1024, ParallelMode.TENSOR:      
                             2048,the default parallel seed is                  
                             ParallelMode.DATA.                                 
                    INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/initialize.py
                             :116 launch                                        
                    INFO     colossalai - colossalai - INFO: Distributed        
                             environment is initialized, data parallel size: 1, 
                             pipeline parallel size: 4, tensor parallel size: 1 
                    INFO     colossalai - colossalai - INFO: train.py:53 main   
                    INFO     colossalai - colossalai - INFO: initialized        
                             distributed environment                            
[03/31/23 13:04:19] INFO     colossalai - ProcessGroup - INFO:                  
                             /workspace/code/ColossalAI/colossalai/tensor/proces
                             s_group.py:22 log_pg_init                          
                    INFO     colossalai - ProcessGroup - INFO: Pytorch          
                             ProcessGroup Init:                                 
                                     backend: nccl                              
                                     ranks: [0]                                 
                    INFO     colossalai - ProcessGroup - INFO:                  
                             /workspace/code/ColossalAI/colossalai/tensor/proces
                             s_group.py:22 log_pg_init                          
                    INFO     colossalai - ProcessGroup - INFO: Pytorch          
                             ProcessGroup Init:                                 
                                     backend: nccl                              
                                     ranks: [1]                                 
                    INFO     colossalai - ProcessGroup - INFO:                  
                             /workspace/code/ColossalAI/colossalai/tensor/proces
                             s_group.py:22 log_pg_init                          
                    INFO     colossalai - ProcessGroup - INFO: Pytorch          
                             ProcessGroup Init:                                 
                                     backend: nccl                              
                                     ranks: [2]                                 
                    INFO     colossalai - ProcessGroup - INFO:                  
                             /workspace/code/ColossalAI/colossalai/tensor/proces
                             s_group.py:22 log_pg_init                          
                    INFO     colossalai - ProcessGroup - INFO: Pytorch          
                             ProcessGroup Init:                                 
                                     backend: nccl                              
                                     ranks: [3]                                 
                    INFO     colossalai - ProcessGroup - INFO:                  
                             /workspace/code/ColossalAI/colossalai/tensor/proces
                             s_group.py:22 log_pg_init                          
                    INFO     colossalai - ProcessGroup - INFO: Pytorch          
                             ProcessGroup Init:                                 
                                     backend: nccl                              
                                     ranks: [0, 1, 2, 3]                        
searching chunk configuration is completed in 0.00 s.
used number: 0.02 MB, wasted number: 0.00 MB
total wasted percentage is 0.00%
False
False
False
[extension] Compiling or loading the JIT-built fused_optim kernel during runtime now
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/colossalai/torch_extensions/torch1.13_cu11.7/build.ninja...
Building extension module fused_optim...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_optim...
c10dDEBUG rank:3 recv, shape:torch.Size([]), device:cuda:3, src:2, group:None, tag:0
Loading extension module fused_optim...
Loading extension module fused_optim...
Loading extension module fused_optim...
[extension] Time to compile or load fused_optim op: 0.4035046100616455 seconds
c10dDEBUG rank:1 recv, shape:torch.Size([]), device:cuda:1, src:0, group:None, tag:0
[03/31/23 13:04:21] INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/initialize.py
                             :265 initialize                                    
                    INFO     colossalai - colossalai - INFO:                    
                             ========== Your Config ========                    
                             {'BATCH_SIZE': 4,                                  
                              'CHECKPOINT': False,                              
                              'DEPTH': 5,                                       
                              'HIDDEN_SIZE': 16,                                
                              'IMG_SIZE': 224,                                  
                              'LEARNING_RATE': 0.003,                           
                              'MLP_RATIO': 2,                                   
                              'NUM_CHUNKS': 2,                                  
                              'NUM_CLASSES': 10,                                
                              'NUM_EPOCHS': 2,                                  
                              'NUM_HEADS': 4,                                   
                              'NUM_MICRO_BATCHES': 4,                           
                              'PATCH_SIZE': 16,                                 
                              'SEQ_LENGTH': 197,                                
                              'TENSOR_PARALLEL_MODE': '1d',                     
                              'TENSOR_PARALLEL_SIZE': 1,                        
                              'WARMUP_EPOCHS': 1,                               
                              'WEIGHT_DECAY': 0.3,                              
                              'clip_grad_norm': 1.0,                            
                              'parallel': {'pipeline': 4, 'tensor': {'mode':    
                             '1d', 'size': 1}}}                                 
                             ================================                   

                    INFO     colossalai - colossalai - INFO:                    
                             /workspace/code/ColossalAI/colossalai/initialize.py
                             :277 initialize                                    
                    INFO     colossalai - colossalai - INFO: cuDNN benchmark =  
                             False, deterministic = False                       
c10dDEBUG rank:2 recv, shape:torch.Size([]), device:cuda:2, src:1, group:None, tag:0
                    WARNING  colossalai - colossalai - WARNING:                 
                             /workspace/code/ColossalAI/colossalai/initialize.py
                             :442 initialize                                    
                    WARNING  colossalai - colossalai - WARNING: No PyTorch DDP  
                             or gradient handler is set up, please make sure you
                             do not need to all-reduce the gradients after a    
                             training step.                                     
                    INFO     colossalai - colossalai - INFO: train.py:102 main  
                    INFO     colossalai - colossalai - INFO: Engine is built    
c10dDEBUG rank:0 send, shape:torch.Size([]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:1 recv, shape:torch.Size([]), device:cuda:1, src:0, group:None, tag:0
c10dDEBUG rank:0 send, shape:torch.Size([]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:0 send, shape:torch.Size([3]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:1 recv, shape:torch.Size([3]), device:cuda:1, src:0, group:None, tag:0
c10dDEBUG rank:0 isend, shape:torch.Size([1, 197, 16]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:1 irecv, shape:torch.Size([1, 197, 16]), device:cuda:1, src:0, group:None, tag:0
c10dDEBUG rank:0 isend, shape:torch.Size([1, 197, 16]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:0 isend, shape:torch.Size([1, 197, 16]), device:cuda:0, dst:1, group:None, tag:0
c10dDEBUG rank:0 irecv, shape:torch.Size([1, 197, 16]), device:cuda:0, src:1, group:None, tag:0
c10dDEBUG rank:0 isend, shape:torch.Size([1, 197, 16]), device:cuda:0, dst:1, group:None, tag:0

According to the log, the pipeline successfully transmitted the meta objects, but got stuck while sending output from rank 0 to rank 1 with isend. Using top, we found that while the job was stuck, the CPU usage of two processes was very high:

  PID USER      PR  NI    VIRT    RES    SHR S  %CPU  %MEM     TIME+ COMMAND                                                               
35597 root      20   0   26.6g   2.3g 679720 R 101.3   0.5   1:53.95 python3.8                                                             
35598 root      20   0   25.3g   1.6g 565736 R 101.0   0.3   1:52.47 python3.8 
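
To see where the spinning ranks are blocked, one option (a sketch, not part of the reproduction above) is to register a stack-dumping signal handler in the script and then send the signal to the two busy PIDs, e.g. kill -USR1 35597; py-spy dump on those PIDs would give similar information:

# Sketch: dump the Python stack of every thread when the process receives SIGUSR1,
# so the stuck ranks can be inspected with "kill -USR1 <pid>" from another shell.
import faulthandler
import signal

faulthandler.register(signal.SIGUSR1, all_threads=True)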

Interestingly, when we set the batch size to 1, the issue does not occur.

Based on the above information, we suspect there is some kind of communication deadlock when pipeline parallelism and Gemini are used together.
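
For further triage, one thing we have not tried yet (sketched below, so only a suggestion) is enabling PyTorch's distributed debug mode and NCCL logging before launch, which can surface mismatched or desynchronized collectives:

# Sketch: extra debugging knobs for torch.distributed / NCCL. These must be set
# before colossalai.launch_from_torch() initializes the process group.
import os

os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")  # extra checks/logging on c10d collectives
os.environ.setdefault("NCCL_DEBUG", "INFO")                 # verbose NCCL-level logging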

Environment

Docker image based on nvidia/cuda:11.7.1-devel-ubuntu20.04. The container startup command we use is: sudo docker run -itd --ipc=host --network=host --gpus all --hostname colossalai --name colossalai --cap-add=SYS_PTRACE ${image_name} /bin/bash
torch 1.13.1

binmakeswell commented 1 year ago

Hi @liuzeming-yuxi Replied in #3403