Closed dangxingyu closed 1 year ago
Here is the code of my opt model with layers from fairscale.nn.model_parallel.layers
:
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairscale
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ParallelEmbedding
import fairscale.nn.model_parallel as mpu
import transformers.models.opt.modeling_opt as huggingface_opt
import fairscale.nn.model_parallel.initialize as fs_init
import transformers.models.opt.modeling_opt as opt
from typing import Optional, Tuple
from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class ModelConfig:
hidden_dim: int = 1024
n_layers: int = 24
n_heads: int = 32
attn_dropout: float = 0.1
ffn_dropout: float = 0.1
# attn_mask: bool = True
enable_bias: bool = True
# init_method = lambda x: x
num_embeddings: int = 50272
embedding_dim: int = 512
max_position_embeddings: int = 2048
norm_eps: float = 1e-5
model_name_or_path: str = 'facebook/opt-350m'
do_layer_norm_before: bool = False
def get_model_config(model_name):
model_config = ModelConfig()
model_config.model_name_or_path = model_name
if model_name == 'facebook/opt-350m':
return model_config
elif model_name == 'facebook/opt-1.3b':
model_config.hidden_dim=2048
model_config.embedding_dim=2048
model_config.do_layer_norm_before=True
return model_config
elif model_name == 'facebook/opt-2.7b':
model_config.hidden_dim=2560
model_config.n_layers=32
model_config.embedding_dim=2560
model_config.do_layer_norm_before=True
class LearnedPositionalEmbedding(ParallelEmbedding):
def __init__(self, config: ModelConfig):
self.offset = 2
super().__init__(
config.max_position_embeddings +
self.offset,
config.hidden_dim,
# init_method=config.init_method
)
def forward(self, attention_mask: torch.LongTensor,
past_key_values_length: int = 0):
attention_mask = attention_mask.long()
positions = (torch.cumsum(attention_mask, dim=1).type_as(
attention_mask) * attention_mask).long() - 1
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
class Attention(nn.Module):
def __init__(self, config: ModelConfig, is_decoder: bool = True):
super(Attention, self).__init__()
self.hidden_dim = config.hidden_dim
self.n_heads = config.n_heads
self.dropout = config.attn_dropout
self.head_dim = self.hidden_dim // self.n_heads
self.n_local_heads = config.n_heads // fs_init.get_model_parallel_world_size()
if self.head_dim * self.n_heads != self.hidden_dim:
raise ValueError(
f"hidden_dim {self.hidden_dim} is not a multiple of n_heads {self.n_heads}"
)
self.scale = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.q_proj = ColumnParallelLinear(
self.hidden_dim,
self.head_dim * self.n_heads,
bias=config.enable_bias,
gather_output=False,
# init_method=config.init_method
)
self.k_proj = ColumnParallelLinear(
self.hidden_dim,
self.head_dim * self.n_heads,
bias=config.enable_bias,
gather_output=False,
# init_method=config.init_method
)
self.v_proj = ColumnParallelLinear(
self.hidden_dim,
self.head_dim * self.n_heads,
bias=config.enable_bias,
gather_output=False,
# init_method=config.init_method
)
self.out_proj = RowParallelLinear(
self.head_dim * self.n_heads,
self.hidden_dim,
bias=config.enable_bias,
input_is_parallel=True,
# init_method=config.init_method
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(
bsz,
seq_len,
self.n_local_heads,
self.head_dim).transpose(
1,
2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
bsz, tgt_len, _ = hidden_states.size()
query_states, key_states, value_states = self.q_proj(
hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
# (bsz, n_local_heads, tgt_len, head_dim)
query_states = self._shape(query_states, -1, bsz)
# (bsz, n_local_heads, src_len, head_dim)
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)
src_len = key_states.size(2)
scores = torch.matmul(query_states, key_states.transpose(
2, 3)) / self.scale # (bsz, n_local_heads, tgt_len, src_len)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
scores = scores.masked_fill(
attention_mask == 0, torch.finfo(
hidden_states.dtype).min)
scores = F.softmax(scores.float(), dim=-1).type_as(query_states)
# (bsz, n_local_heads, tgt_len, head_dim)
output = torch.matmul(scores, value_states)
output = output.transpose(1, 2).contiguous().view(bsz, tgt_len, -1)
output = self.out_proj(output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.self_attn = Attention(config, is_decoder=True)
self.attn_dropout = nn.Dropout(config.attn_dropout)
self.ffn_dropout = nn.Dropout(config.ffn_dropout)
# self.self_attn_layer_norm = RMSNorm(
# config.hidden_dim, eps=config.norm_eps)
# self.final_layer_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
self.self_attn_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
self.final_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
self.fc1 = ColumnParallelLinear(
config.hidden_dim,
config.hidden_dim * 4,
bias=config.enable_bias,
# init_method=config.init_method,
gather_output=False)
self.fc2 = RowParallelLinear(
config.hidden_dim * 4,
config.hidden_dim,
bias=config.enable_bias,
# init_method=config.init_method,
input_is_parallel=True)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
):
residual = hidden_states
if self.config.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states,
attention_mask=attention_mask,
)
hidden_states = self.attn_dropout(hidden_states)
hidden_states = residual + hidden_states
if not self.config.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
if self.config.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
hidden_states = self.ffn_dropout(hidden_states)
hidden_states = residual + hidden_states
if not self.config.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class OPTDecoder(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.n_layers)])
self.embed_tokens = VocabParallelEmbedding(
config.num_embeddings,
config.embedding_dim,
# init_method=config.init_method,
padding_idx=1)
self.embed_positions = LearnedPositionalEmbedding(config)
if config.embedding_dim != config.hidden_dim:
self.project_in = nn.Linear(
config.embedding_dim,
config.hidden_dim,
bias=False)
self.project_out = nn.Linear(
config.hidden_dim,
config.embedding_dim,
bias=False)
else:
self.project_in = self.project_out = None
if config.do_layer_norm_before:
self.final_layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.norm_eps)
def forward(self, input_ids, attention_mask):
bsz, tgt_len = input_ids.size()
x = self.embed_tokens(input_ids)
if self.project_in is not None:
x = self.project_in(x)
positions = self.embed_positions(attention_mask)
x = x + positions
for layer in self.layers:
x = layer(x, attention_mask)
if self.project_out is not None:
x = self.project_out(x)
if self.final_layer_norm is not None:
x = self.final_layer_norm(x)
return x
class OPTModel(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.decoder = OPTDecoder(config)
def forward(self, input_ids, attention_mask):
x = self.decoder(input_ids, attention_mask)
return x
class OPTForCausalLM(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.model = OPTModel(config)
self.lm_head = nn.Linear(
config.embedding_dim,
config.num_embeddings,
bias=False)
def forward(self, input_ids, attention_mask, labels):
x = self.model(input_ids, attention_mask)
logits = self.lm_head(x)
# print(logits.device, logits.size(), labels.device, labels.size())
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# print(shift_logits, shift_labels)
# print(shift_logits.size(), shift_labels.size())
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
# loss_fct = vocab_parallel_cross_entropy
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return (loss, logits) if loss is not None else (logits,)
And here is the code of training procedure:
def train(
args,
model,
rank,
world_size,
train_loader,
optimizer,
epoch,
sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
# tensor parallelism training loop
# only rank 0 will load from the dataloader
# and then broadcast the data to all other ranks
# all other ranks will wait for the data from rank 0
if rank == 0:
for batch in train_loader:
for k, v in batch.items():
batch[k] = v.to(local_rank)
input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
# dist.broadcast(torch.stack([input_ids, attention_mask, labels]), 0)
# dist.barrier()
batch_size = input_ids.shape[0]
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels)
loss = output[0]
loss.backward()
optimizer.step()
optimizer.zero_grad()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += batch_size
inner_pbar.update(1)
if rank == 0:
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
I've also tried PyTorch and Deepspeed FSDP, and I'm able to run the opt-1.3B model on my devices without encountering any memory issues.
Did you try DDP or FSDP from pytorch? I am not familiar with the model parallel code at the moment. I do know that the recently released llama code on github uses model parallel code. Maybe you can checkout the code there. sorry about not able to help much.
Hi Min, I've tried FSDP from PyTorch torch.distributed.fsdp
and it works well! Yeah! I use the llama code as a reference for writing the model with the FairScale model parallel layers, but the llama code is only released for inference, there isn't any training example for model parallelism.
I see. Is FSDP from pytorch not sufficient so that you need to use fairscale's model parallel code?
Yep! Actually, I'm trying to finetune llama, which is indeed based on the fairscale's model parallel code
I'm currently working on distributed training of a large language model and I'm using opt-1.3B with layers from
fairscale.nn.model_parallel.layers
and split checkpoints for loading. However, I'm experiencing unexpected memory consumption during training.I'm using the OSS optimizer to reduce the redundant optimizer state, and I'm only loading the data and running the pipeline on rank 0 since I'm using tensor parallelism. Despite this, I'm encountering CUDA out of memory errors when training with 8 RTX2080Ti GPUs, each with 10GB memory.
I've also tried PyTorch and Deepspeed FSDP, and I'm able to run the opt-1.3B model on my devices without encountering any memory issues.
I'm wondering if there's something wrong with my training procedure or if I've written something wrong with the model in mpu form. Additionally, I would appreciate it if someone could provide some sample training code using the fairscale tensor parallelism framework.
Thank you!