facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

Unexpected Large Memory Consumption during Tensor Parallelism Training with OPT-1.3B #1111

Closed: dangxingyu closed this issue 1 year ago

dangxingyu commented 1 year ago

I'm currently working on distributed training of a large language model and I'm using opt-1.3B with layers from fairscale.nn.model_parallel.layers and split checkpoints for loading. However, I'm experiencing unexpected memory consumption during training.

I'm using the OSS optimizer to reduce the redundant optimizer state, and I'm only loading the data and running the pipeline on rank 0 since I'm using tensor parallelism. Despite this, I'm encountering CUDA out of memory errors when training with 8 RTX2080Ti GPUs, each with 10GB memory.
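As a rough sanity check on those numbers, the memory taken by weights, gradients, and optimizer state can be estimated by hand. The sketch below assumes fp32 training with Adam (two moment buffers per parameter) and an even split of every parameter across the tensor-parallel ranks, which is optimistic because layers such as LayerNorm are replicated; `training_state_gib` is a hypothetical helper, not part of fairscale.

```python
def training_state_gib(n_params, tp_degree=1, bytes_per_param=4, adam_states=2):
    """Rough per-GPU memory (GiB) for weights, gradients, and Adam moments.

    Assumption: every parameter is split evenly across `tp_degree` ranks.
    Activations and temporary buffers are NOT included.
    """
    per_rank = n_params / tp_degree
    weights = per_rank * bytes_per_param
    grads = per_rank * bytes_per_param
    optim = per_rank * bytes_per_param * adam_states
    return (weights + grads + optim) / 2**30


if __name__ == "__main__":
    n = 1.3e9  # approximate parameter count of OPT-1.3B
    print(f"unsharded: {training_state_gib(n):.2f} GiB")
    print(f"8-way TP:  {training_state_gib(n, tp_degree=8):.2f} GiB")
```

Under these assumptions the unsharded state is roughly 19 GiB and an ideal 8-way split is about 2.4 GiB per GPU. If the model really is sharded 8 ways, an out-of-memory error on a 10GB card therefore points more toward activations or unsharded tensors than toward parameter state.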

I've also tried PyTorch and Deepspeed FSDP, and I'm able to run the opt-1.3B model on my devices without encountering any memory issues.

I'm wondering if there's something wrong with my training procedure, or if I've made a mistake in the model-parallel (mpu) version of the model. Additionally, I would appreciate it if someone could provide some sample training code using the fairscale tensor parallelism framework.

Thank you!

dangxingyu commented 1 year ago

Here is the code for my OPT model, using layers from fairscale.nn.model_parallel.layers:

import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairscale
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ParallelEmbedding
import fairscale.nn.model_parallel as mpu
import transformers.models.opt.modeling_opt as huggingface_opt
import fairscale.nn.model_parallel.initialize as fs_init
import transformers.models.opt.modeling_opt as opt
from typing import Optional, Tuple
from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class ModelConfig:
    hidden_dim: int = 1024
    n_layers: int = 24
    n_heads: int = 32
    attn_dropout: float = 0.1
    ffn_dropout: float = 0.1
    # attn_mask: bool = True
    enable_bias: bool = True
    # init_method = lambda x: x
    num_embeddings: int = 50272
    embedding_dim: int = 512
    max_position_embeddings: int = 2048
    norm_eps: float = 1e-5
    model_name_or_path: str = 'facebook/opt-350m'
    do_layer_norm_before: bool = False

def get_model_config(model_name):
    model_config = ModelConfig()
    model_config.model_name_or_path = model_name
    if model_name == 'facebook/opt-350m':
        return model_config
    elif model_name == 'facebook/opt-1.3b':
        model_config.hidden_dim=2048
        model_config.embedding_dim=2048
        model_config.do_layer_norm_before=True
        return model_config
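    elif model_name == 'facebook/opt-2.7b':
        model_config.hidden_dim=2560
        model_config.n_layers=32
        model_config.embedding_dim=2560
        model_config.do_layer_norm_before=True
        return model_config
    # Fail loudly instead of silently returning None for unknown names.
    raise ValueError(f"Unsupported model name: {model_name}")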

class LearnedPositionalEmbedding(ParallelEmbedding):

    def __init__(self, config: ModelConfig):
        self.offset = 2
        super().__init__(
            config.max_position_embeddings +
            self.offset,
            config.hidden_dim,
            # init_method=config.init_method
        )

    def forward(self, attention_mask: torch.LongTensor,
                past_key_values_length: int = 0):
        attention_mask = attention_mask.long()
        positions = (torch.cumsum(attention_mask, dim=1).type_as(
            attention_mask) * attention_mask).long() - 1
        positions = positions[:, past_key_values_length:]
        return super().forward(positions + self.offset)

class Attention(nn.Module):
    def __init__(self, config: ModelConfig, is_decoder: bool = True):
        super(Attention, self).__init__()
        self.hidden_dim = config.hidden_dim
        self.n_heads = config.n_heads
        self.dropout = config.attn_dropout
        self.head_dim = self.hidden_dim // self.n_heads

        self.n_local_heads = config.n_heads // fs_init.get_model_parallel_world_size()

        if self.head_dim * self.n_heads != self.hidden_dim:
            raise ValueError(
                f"hidden_dim {self.hidden_dim} is not a multiple of n_heads {self.n_heads}"
            )

        self.scale = self.head_dim ** -0.5
        self.is_decoder = is_decoder

        self.q_proj = ColumnParallelLinear(
            self.hidden_dim,
            self.head_dim * self.n_heads,
            bias=config.enable_bias,
            gather_output=False,
            # init_method=config.init_method
        )

        self.k_proj = ColumnParallelLinear(
            self.hidden_dim,
            self.head_dim * self.n_heads,
            bias=config.enable_bias,
            gather_output=False,
            # init_method=config.init_method
        )

        self.v_proj = ColumnParallelLinear(
            self.hidden_dim,
            self.head_dim * self.n_heads,
            bias=config.enable_bias,
            gather_output=False,
            # init_method=config.init_method
        )

        self.out_proj = RowParallelLinear(
            self.head_dim * self.n_heads,
            self.hidden_dim,
            bias=config.enable_bias,
            input_is_parallel=True,
            # init_method=config.init_method
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(
            bsz,
            seq_len,
            self.n_local_heads,
            self.head_dim).transpose(
            1,
            2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        bsz, tgt_len, _ = hidden_states.size()
        query_states, key_states, value_states = self.q_proj(
            hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        # (bsz, n_local_heads, tgt_len, head_dim)
        query_states = self._shape(query_states, -1, bsz)
        # (bsz, n_local_heads, src_len, head_dim)
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        src_len = key_states.size(2)
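        # self.scale already holds head_dim ** -0.5, so multiply (dividing
        # would scale the scores up by sqrt(head_dim) instead of down).
        scores = torch.matmul(query_states, key_states.transpose(
            2, 3)) * self.scale  # (bsz, n_local_heads, tgt_len, src_len)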

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(
                attention_mask == 0, torch.finfo(
                    hidden_states.dtype).min)

        scores = F.softmax(scores.float(), dim=-1).type_as(query_states)
        # (bsz, n_local_heads, tgt_len, head_dim)
        output = torch.matmul(scores, value_states)
        output = output.transpose(1, 2).contiguous().view(bsz, tgt_len, -1)
        output = self.out_proj(output)
        return output

class OPTDecoderLayer(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.self_attn = Attention(config, is_decoder=True)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout)
        # self.self_attn_layer_norm = RMSNorm(
        #     config.hidden_dim, eps=config.norm_eps)
        # self.final_layer_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
        self.final_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
        self.fc1 = ColumnParallelLinear(
            config.hidden_dim,
            config.hidden_dim * 4,
            bias=config.enable_bias,
            # init_method=config.init_method,
            gather_output=False)
        self.fc2 = RowParallelLinear(
            config.hidden_dim * 4,
            config.hidden_dim,
            bias=config.enable_bias,
            # init_method=config.init_method,
            input_is_parallel=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        residual = hidden_states
        if self.config.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = self.attn_dropout(hidden_states)
        hidden_states = residual + hidden_states
        if not self.config.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        residual = hidden_states
        if self.config.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
        hidden_states = self.ffn_dropout(hidden_states)
        hidden_states = residual + hidden_states
        if not self.config.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

class OPTDecoder(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.n_layers)])
        self.embed_tokens = VocabParallelEmbedding(
            config.num_embeddings,
            config.embedding_dim,
            # init_method=config.init_method,
            padding_idx=1)
        self.embed_positions = LearnedPositionalEmbedding(config)
        if config.embedding_dim != config.hidden_dim:
            self.project_in = nn.Linear(
                config.embedding_dim,
                config.hidden_dim,
                bias=False)
            self.project_out = nn.Linear(
                config.hidden_dim,
                config.embedding_dim,
                bias=False)
        else:
            self.project_in = self.project_out = None
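        if config.do_layer_norm_before:
            self.final_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
        else:
            # Keep the attribute defined so forward() can compare it to None.
            self.final_layer_norm = None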

    def forward(self, input_ids, attention_mask):
        bsz, tgt_len = input_ids.size()
        x = self.embed_tokens(input_ids)
        if self.project_in is not None:
            x = self.project_in(x)
        positions = self.embed_positions(attention_mask)
        x = x + positions
        for layer in self.layers:
            x = layer(x, attention_mask)
        if self.project_out is not None:
            x = self.project_out(x)
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x

class OPTModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(self, input_ids, attention_mask):
        x = self.decoder(input_ids, attention_mask)
        return x

class OPTForCausalLM(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.model = OPTModel(config)
        self.lm_head = nn.Linear(
            config.embedding_dim,
            config.num_embeddings,
            bias=False)

    def forward(self, input_ids, attention_mask, labels):
        x = self.model(input_ids, attention_mask)
        logits = self.lm_head(x)
        # print(logits.device, logits.size(), labels.device, labels.size())
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # print(shift_logits, shift_labels)
            # print(shift_logits.size(), shift_labels.size())
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # loss_fct = vocab_parallel_cross_entropy
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return (loss, logits) if loss is not None else (logits,)
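
One thing worth checking in the model above is activation memory that tensor parallelism does not shrink. The lm_head is a plain nn.Linear, so each rank that runs it materializes full-vocabulary logits, and the attention path stores a (bsz, n_local_heads, tgt_len, src_len) score tensor per layer. The sketch below only does the arithmetic; the batch size and sequence length are assumptions (the post does not state them), and both function names are hypothetical.

```python
def logits_gib(batch_size, seq_len, vocab_size=50272, bytes_per_elem=4):
    """Size (GiB) of one fp32 copy of the full-vocabulary logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem / 2**30


def attn_scores_gib(batch_size, seq_len, n_local_heads, n_layers=24, bytes_per_elem=4):
    """Size (GiB) of one fp32 attention-score tensor per layer, summed over layers."""
    per_layer = batch_size * n_local_heads * seq_len**2 * bytes_per_elem
    return n_layers * per_layer / 2**30


if __name__ == "__main__":
    # Assumed workload: batch of 1, 2048 tokens, 32 heads split over 8 ranks.
    print(f"logits:      {logits_gib(1, 2048):.2f} GiB per copy")
    print(f"attn scores: {attn_scores_gib(1, 2048, n_local_heads=4):.2f} GiB per copy")
```

Autograd typically keeps more than one tensor of each shape alive (for example the masked scores, the float softmax output, and the contiguous shift_logits copy), so the real footprint is a multiple of these figures and grows linearly with batch size.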

And here is the code for the training procedure:

import os

import torch
import tqdm  # progress bar shown on rank 0

def train(
        args,
        model,
        rank,
        world_size,
        train_loader,
        optimizer,
        epoch,
        sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )

    # tensor parallelism training loop
    # only rank 0 will load from the dataloader
    # and then broadcast the data to all other ranks
    # all other ranks will wait for the data from rank 0

    if rank == 0:
        for batch in train_loader:
            for k, v in batch.items():
                batch[k] = v.to(local_rank)
            input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
            # dist.broadcast(torch.stack([input_ids, attention_mask, labels]), 0)
            # dist.barrier()
            batch_size = input_ids.shape[0]
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels)
            loss = output[0]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            fsdp_loss[0] += loss.item()
            fsdp_loss[1] += batch_size
            inner_pbar.update(1)

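    # Only rank 0 accumulated statistics, so only rank 0 can report them;
    # other ranks return None instead of hitting an undefined variable.
    train_loss = None
    if rank == 0:
        train_loss = fsdp_loss[0] / fsdp_loss[1]
        inner_pbar.close()
        print(
            f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}"
        )
    return train_loss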
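
The comments in the loop above describe rank 0 loading each batch and broadcasting it, but the broadcast is commented out and only rank 0 calls the model. With tensor-parallel layers, every rank in the model-parallel group normally has to run the same forward and backward pass so that the collectives inside the parallel layers line up. Below is a minimal sketch of the missing broadcast step, assuming the three tensors share one 2-D shape and can be sent as int64; `broadcast_batch` is a hypothetical helper, not a fairscale API, and it uses only torch.distributed calls.

```python
import torch
import torch.distributed as dist


def broadcast_batch(batch, src=0, device="cpu"):
    """Send input_ids, attention_mask, and labels from rank `src` to all ranks.

    `batch` is only read on the source rank; other ranks may pass None.
    Assumes a process group is already initialized and that all three
    tensors have the same (batch, seq_len) shape.
    """
    keys = ["input_ids", "attention_mask", "labels"]
    is_src = dist.get_rank() == src

    # Step 1: share the shape so receivers can allocate a matching buffer.
    if is_src:
        shape = torch.tensor(list(batch["input_ids"].shape), dtype=torch.long, device=device)
    else:
        shape = torch.zeros(2, dtype=torch.long, device=device)
    dist.broadcast(shape, src=src)

    # Step 2: broadcast the three tensors stacked into one payload.
    if is_src:
        payload = torch.stack([batch[k].to(device).long() for k in keys])
    else:
        payload = torch.empty((len(keys), *shape.tolist()), dtype=torch.long, device=device)
    dist.broadcast(payload, src=src)

    return dict(zip(keys, payload.unbind(0)))
```

In a loop like the one above, rank 0 would pass each batch from train_loader and the other ranks would pass None; every rank would then call the model, loss.backward(), and optimizer.step() on the returned tensors. With the NCCL backend, `device` would need to be the local CUDA device.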

min-xu-ai commented 1 year ago

> I've also tried PyTorch and Deepspeed FSDP, and I'm able to run the opt-1.3B model on my devices without encountering any memory issues.

Did you try DDP or FSDP from PyTorch? I am not familiar with the model parallel code at the moment. I do know that the recently released llama code on GitHub uses model parallel code, so maybe you can check out the code there. Sorry I'm not able to help much.

dangxingyu commented 1 year ago

Hi Min, I've tried FSDP from PyTorch (torch.distributed.fsdp) and it works well! I used the llama code as a reference for writing the model with the FairScale model parallel layers, but the llama code was only released for inference; there isn't any training example for model parallelism.

min-xu-ai commented 1 year ago

I see. Is FSDP from pytorch not sufficient so that you need to use fairscale's model parallel code?

dangxingyu commented 1 year ago

Yep! Actually, I'm trying to fine-tune llama, which is based on fairscale's model parallel code.