microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Using pipeline parallelism gets stuck partway through training #5568

Open HackGiter opened 1 month ago

HackGiter commented 1 month ago

Describe the bug I am trying to use pipeline parallelism with a transformer model, but training gets stuck partway through.

To Reproduce Steps to reproduce the behavior:

#!/usr/bin/env python3

import os
import argparse

import torch
import torch.distributed as dist
from torch import nn

from typing import Any, Optional
from transformers import LlamaForCausalLM, Seq2SeqTrainingArguments
from llamafactory.data import get_dataset, split_dataset
from llamafactory.model import load_model, load_tokenizer
from llamafactory.hparams import ModelArguments, DataArguments, get_train_args

import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.utils import RepeatingLoader
from transformers import DataCollatorForSeq2Seq

from dataclasses import dataclass
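
# Note (comment added for clarity, not in the original report): DeepSpeed's
# PipelineModule passes a tuple of tensors from stage to stage, so each wrapper
# below accepts a tuple, applies the wrapped Hugging Face submodule to the first
# element, and forwards the remaining elements (e.g. the attention mask) unchanged.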

class PipelineEmbedLayer(nn.Module):
    def __init__(self, layer) -> None:
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return (self.layer(x[0]),) + x[1:]

class PipelineTransformerLayer(nn.Module):
    def __init__(self, layer) -> None:
        super().__init__()
        self.layer = layer

    def forward(self, x):
        hidden_states, attention_mask = x
        position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        # Convert the boolean mask from the collator into an additive float mask:
        # True (masked) positions become the dtype's most negative value.
        causal_mask = torch.zeros_like(attention_mask, dtype=hidden_states.dtype)
        causal_mask[attention_mask] = torch.finfo(hidden_states.dtype).min
        return (self.layer(hidden_states, causal_mask, position_ids)[0], attention_mask, )

class PipelineNorm(nn.Module):
    def __init__(self, layer) -> None:
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return (self.layer(x[0]),) + x[1:]

class PipelineLMLayer(nn.Module):
    def __init__(self, layer) -> None:
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return (self.layer(x[0]),) + x[1:]

class CrossEntropyLossPipeline(nn.Module):
    def __init__(self,
                 weight: Optional[torch.Tensor] = None, 
                 size_average=None,
                 ignore_index: int = -100,
                 reduce: Optional[Any] = None,
                 reduction: str = 'mean',
                 label_smoothing: float = 0,
                 ) -> None:
        super().__init__()
        self.weight = weight
        self.size_average = size_average
        self.reduce = reduce
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        # Standard causal-LM shift: logits at position t predict the token at t+1.
        logits = inputs[0][..., :-1, :].contiguous().view(-1, inputs[0].shape[-1])
        labels = targets[..., 1:].contiguous().view(-1)
        return nn.functional.cross_entropy(logits, labels, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction,
                               label_smoothing=self.label_smoothing)

class GPTPipe(PipelineModule):
    def __init__(self, model, loss_fn, num_stages, partition_method, activation_checkpoint_interval):
        layers = [
            LayerSpec(PipelineEmbedLayer, layer=model.model.embed_tokens),
            *[LayerSpec(PipelineTransformerLayer, layer=layer) for layer in model.model.layers],
            LayerSpec(PipelineNorm, layer=model.model.norm),
            LayerSpec(PipelineLMLayer, layer=model.lm_head)
        ]
        super().__init__(
            layers=layers, 
            loss_fn=loss_fn, 
            num_stages=num_stages, 
            partition_method=partition_method, 
            activation_checkpoint_interval=activation_checkpoint_interval,
            )
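
# Note (comment added for clarity, not in the original report): the barriers in
# handle_dataset implement a "rank 0 first" pattern -- all ranks sync once,
# non-zero ranks then wait at a second barrier while rank 0 tokenizes and caches
# the dataset, and rank 0's final barrier releases them to load from the cache.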

def handle_dataset(
        local_rank:int,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        training_args: "Seq2SeqTrainingArguments",
        ):
    dist.barrier()
    if local_rank != 0:
        dist.barrier()

    tokenizer_module = load_tokenizer(model_args)
    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
    dataset = split_dataset(dataset, data_args, training_args)

    if local_rank == 0:
        dist.barrier()
    return dataset

def get_args():
    parser = argparse.ArgumentParser(description='TRAIN')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('-s',
                        '--steps',
                        type=int,
                        default=100,
                        help='quit after this many steps')
    parser.add_argument('-p',
                        '--pipeline-parallel-size',
                        type=int,
                        default=4,
                        help='pipeline parallelism')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='distributed backend')
    parser.add_argument('--seed', type=int, default=1138, help='PRNG seed')
    parser = deepspeed.add_config_arguments(parser)
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()

    unknown_args_dict = {}
    unknown_key = None
    for arg in unknown:
        if arg.startswith('--'):
            if unknown_key is not None:
                unknown_args_dict[unknown_key] = None
            unknown_key = arg.replace('--', '')
        else:
            if unknown_key is not None:
                # Try to convert the value to int or float
                if arg.isdigit():
                    unknown_args_dict[unknown_key] = int(arg)
                else:
                    try:
                        unknown_args_dict[unknown_key] = float(arg)
                    except ValueError:
                        if arg == 'true':
                            unknown_args_dict[unknown_key] = True
                        elif arg == 'false':
                            unknown_args_dict[unknown_key] = False
                        else:
                            unknown_args_dict[unknown_key] = arg
                unknown_key = None
    if unknown_key is not None:
        unknown_args_dict[unknown_key] = None

    return args, unknown_args_dict

def train_base(args, unknown):
    torch.manual_seed(args.seed)

    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(unknown)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    dataset = handle_dataset(args.local_rank, model_args, data_args, training_args)

    engine, optimizer, dataloader, __ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        training_data=dataset['train_dataset'])

    dataloader = RepeatingLoader(dataloader)
    data_iter = iter(dataloader)

    rank = dist.get_rank()
    gas = engine.gradient_accumulation_steps()

    criterion = torch.nn.CrossEntropyLoss()

    total_steps = args.steps * gas
    step = 0
    for micro_step in range(total_steps):
        batch = next(data_iter)
        inputs = batch[0].to(engine.device)
        labels = batch[1].to(engine.device)

        outputs = engine(inputs)
        loss = criterion(outputs, labels)
        engine.backward(loss)
        engine.step()

        if micro_step % gas == 0:
            step += 1
            if rank == 0 and (step % 10 == 0):
                print(f'step: {step:3d} / {args.steps:3d} loss: {loss}')

def join_layers(vision_model):
    layers = [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]
    return layers

def join_layers_for_llama(model):
    layers = [
        PipelineEmbedLayer(model.model.embed_tokens),
        *[PipelineTransformerLayer(layer) for layer in model.model.layers],
        PipelineNorm(model.model.norm),
        PipelineLMLayer(model.lm_head), 
    ]
    return layers
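
# Note (comment added for clarity, not in the original report): this collator
# wraps DataCollatorForSeq2Seq and additionally materializes a 4-D boolean
# attention mask (True = masked) that combines the causal mask with the padding
# mask, so the pipeline stages receive an (input_ids, mask) tuple and the loss
# stage receives the labels.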

@dataclass
class DataCollator:
    data_collator:DataCollatorForSeq2Seq
    dtype: torch.dtype = torch.bfloat16
    def __call__(self, features, **kwargs):
        inputs = self.data_collator(features, **kwargs)
        input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
        device = inputs['input_ids'].device
        bsz, seq_len = inputs['input_ids'].shape
        causal_mask = torch.full((seq_len, seq_len), fill_value=True, dtype=torch.bool, device=device)
        if seq_len != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask = causal_mask[None, None, :, :].expand(bsz, 1, -1, -1)
        causal_mask = causal_mask.clone()
        if attention_mask.dim() == 2:
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[..., :mask_length]
            padding_mask[attention_mask[:, None, None, :].expand(-1, -1, padding_mask.shape[-2], -1) == 0] = True
            causal_mask[..., :mask_length] = padding_mask
        return (input_ids, causal_mask, ), inputs['labels']

IGNORE_INDEX = -100
def train_pipe(args, unknown, part='parameters'):
    torch.manual_seed(args.seed)
    deepspeed.runtime.utils.set_random_seed(args.seed)
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(unknown)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    from transformers import LlamaForCausalLM, LlamaConfig
    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
    config.num_hidden_layers = 4
    model = LlamaForCausalLM(config)
    dataset = handle_dataset(args.local_rank, model_args, data_args, training_args)
    print(args)
    print(len(dataset['train_dataset']))

    model = GPTPipe(
        model=model,
        loss_fn=CrossEntropyLossPipeline(),
        num_stages=args.pipeline_parallel_size,
        # partition_method='type:transformer',
        partition_method='parameters',
        activation_checkpoint_interval=0
    )

    if dist.get_rank() == 0:
        # print(dataset['train_dataset'][0])
        pass
    tokenizer.pad_token = tokenizer.eos_token
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX,
    )
    data_collator = DataCollator(data_collator)

    if dist.get_rank() == 0:
        for name, p in model.named_parameters():
            print(name, p.requires_grad)

    engine, _, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        training_data=dataset['train_dataset'],
        collate_fn=data_collator)

    engine.set_has_attention_mask(True)
    for i in range(args.steps):
        loss = engine.train_batch()
        print(i, loss)

if __name__ == '__main__':
    args, unknown = get_args()

    deepspeed.init_distributed(dist_backend=args.backend)
    args.local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(args.local_rank)

    if dist.get_rank() == 0:
        print(args)

    if args.pipeline_parallel_size == 0:
        train_base(args, unknown)
    else:
        train_pipe(args, unknown)

Expected behavior The run should train to completion and exit successfully.

Additional Information I traced the execution and found that it gets stuck in the LoadMicroBatch instruction of PipelineEngine.
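
As a debugging aid (a sketch added here for illustration, not part of the original script), a thin wrapper around the collate function can log, on every rank, the shapes and dtypes of the first few micro-batches it produces; if the first and last pipeline stages report different shapes or stop producing batches at different points, that is consistent with the schedule blocking in LoadMicroBatch. The wrapper assumes the DataCollator defined above, which returns ((input_ids, causal_mask), labels).

import torch.distributed as dist

def logging_collate(collate_fn, max_logs=4):
    # Wrap a collate_fn and print, per rank, the shapes/dtypes of the first
    # few micro-batches handed to PipelineEngine's LoadMicroBatch step.
    state = {"n": 0}

    def wrapper(features):
        (input_ids, causal_mask), labels = collate_fn(features)
        if state["n"] < max_logs:
            rank = dist.get_rank() if dist.is_initialized() else 0
            print(f"[rank {rank}] micro-batch {state['n']}: "
                  f"input_ids {tuple(input_ids.shape)} {input_ids.dtype}, "
                  f"mask {tuple(causal_mask.shape)} {causal_mask.dtype}, "
                  f"labels {tuple(labels.shape)} {labels.dtype}", flush=True)
            state["n"] += 1
        return (input_ids, causal_mask), labels

    return wrapper

# Hypothetical usage in train_pipe above:
#   data_collator = logging_collate(DataCollator(data_collator))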

System info (please complete the following information):

Launcher context

#!/bin/bash

NNODES=1 deepspeed \
    --include localhost:4,5,6,7 \
    general_distributed_training.py \
    --model_name_or_path /data0/pretrained-models/deepseek-llm-7b-base \
    --stage sft \
    --do_train true \
    --finetuning_type full \
    --output_dir saves/llama3-8b/full/sft \
    --logging_steps 100 \
    --save_steps 500 \
    --plot_loss true \
    --overwrite_output_dir true \
    --deepspeed \
    --deepspeed_config general_distributed_training.json \
    --dataset_dir /home/hzli/LLaMA-Factory/data \
    --dataset identity,alpaca_gpt4_en \
    --template llama3 \
    --cutoff_len 2048 \
    --max_samples 1000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --bf16 true \
    --learning_rate 0.0001 \
    --num_train_epochs 3.0 \
    --lr_scheduler_type cosine \
    --val_size 0.1 \
    --per_device_eval_batch_size 1 \
    --evaluation_strategy steps \
    --eval_steps 500 \
    --disable_tqdm false \
    --p 4 \
    --steps 1000
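
The referenced general_distributed_training.json is not included in the report. As a point of reference only (an illustrative assumption, not the author's actual file), a minimal DeepSpeed configuration consistent with the flags above needs the batch-size fields to agree with the topology (4 GPUs, all used as pipeline stages, so data-parallel size 1); it could also be passed as a dict through the config argument of deepspeed.initialize:

# Illustrative sketch only -- not the actual general_distributed_training.json.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 8,
    "train_batch_size": 16,  # 2 (micro) * 8 (accumulation) * 1 (data-parallel replicas)
    "steps_per_print": 100,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}
# e.g. deepspeed.initialize(args=args, model=model, config=ds_config, ...)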

Yuhanleeee commented 4 days ago

Thanks for your effort. May I ask how you solved it? I have the same problem.

HackGiter commented 3 days ago

> Thanks for your effort. May I ask how you solved it? I have the same problem.

Actually, I changed the model and used a different dataset to work around it, and it still happens occasionally.