pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Pytorch Transformer Distributed Training #116400

Open aminaqi opened 6 months ago

aminaqi commented 6 months ago

🐛 Describe the bug

Code:

```python

from torchtext.vocab import build_vocab_from_iterator
import torchtext
from typing import Iterable, List
import random
import os
import torch
from tqdm import tqdm
import string
import json
import unicodedata
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Set up distributed training if using multiple GPUs
if torch.cuda.device_count() > 1:
    print('MULTIPLE GPUS !!!')

all_letters = string.ascii_lowercase + string.ascii_uppercase + string.punctuation +' ' + 'م' + '\t'
in_vocab = 'دشپ' + all_letters
out_vocab = 'دشپ' + all_letters
PAD_IDX = 0
START_TOKEN_IDX = 1
END_TOKEN_IDX = 2
BATCH_SIZE = 64

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

data = []
for _ in range(100000):
  inputs = [random.randint(1, 999) for _ in range(random.randint(5, 10))]
  target_query = str(sum(inputs))
  input_query = '+'.join(list(map(lambda x: str(x), inputs)))
  data.append((input_query, target_query))

data_test = []
for _ in range(1000):
  inputs = [random.randint(1, 999) for _ in range(random.randint(5, 10))]
  target_query = str(sum(inputs))
  input_query = '+'.join(list(map(lambda x: str(x), inputs)))
  data_test.append((input_query, target_query))

num_train = len(data) - int(0.15 * len(data))
train_data = data[:num_train]
valid_data = data[num_train:]
print("Total Data :", len(data))
print("Train Data :", len(train_data))
print("Valid Data :", len(valid_data))
print("Test Data : ", len(data_test))

def token_transform(input_str):
    return [ch for ch in input_str]

# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, index: int) -> List[str]:

    for data_sample in data_iter:
        yield token_transform(data_sample[index])

input_vocab = build_vocab_from_iterator(yield_tokens(data, 0),
                                                    min_freq=1,
                                                    specials=['د', 'ش', 'پ'],
                                                    special_first=True)

output_vocab = build_vocab_from_iterator(yield_tokens(data, 0),
                                                    min_freq=1,
                                                    specials=['د', 'ش', 'پ'],
                                                    special_first=True)

input_vocab.set_default_index(0)
output_vocab.set_default_index(0)

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

torch.manual_seed(0)

SRC_VOCAB_SIZE = len(input_vocab)
TGT_VOCAB_SIZE = len(output_vocab)
EMB_SIZE = 1024
NHEAD = 8
FFN_HID_DIM = 2048
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

if torch.cuda.device_count() > 1:
    transformer = nn.DataParallel(transformer)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([START_TOKEN_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([END_TOKEN_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensor indices
text_transform = {}

text_transform["INPUT"] = sequential_transforms(token_transform, #Tokenization
                                               input_vocab, #Numericalization
                                               tensor_transform)

text_transform["OUTPUT"] = sequential_transforms(token_transform, #Tokenization
                                               output_vocab, #Numericalization
                                               tensor_transform)

# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform['INPUT'](src_sample))
        tgt_batch.append(text_transform['OUTPUT'](tgt_sample))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    #[5071 * BATCH_SIZE:]
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    iter = tqdm(train_dataloader)
    accuracy = 0
    for step, (src, tgt) in enumerate(iter):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]

        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1).type(torch.long))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        acc = (logits.argmax(2) == tgt_out).sum().item() / (tgt_out.size(1) * tgt_out.size(0))
        accuracy += acc
        #print(loss)
        iter.set_postfix(train_loss=losses / (step + 1), accuracy=acc)

    return losses / len(list(train_dataloader)), accuracy / len(list(train_dataloader))

def evaluate(model):
    model.eval()
    losses = 0
    total_acc = 0

    val_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1).long())
        losses += loss.item()
        total_acc += (logits.argmax(2) == tgt_out).sum().item() / (tgt_out.size(1) * tgt_out.size(0))

    return losses / len(list(val_dataloader)), total_acc /  len(list(val_dataloader))

from timeit import default_timer as timer
NUM_EPOCHS = 180
best_loss = 10
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss, train_acc = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss, val_acc = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}, Val loss: {val_loss:.3f}, Val Acc: {val_acc:.2f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

Error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 6
      4 for epoch in range(1, NUM_EPOCHS+1):
      5     start_time = timer()
----> 6     train_loss, train_acc = train_epoch(transformer, optimizer)
      7     end_time = timer()
      8     val_loss, val_acc = evaluate(transformer)

Cell In[1], line 278, in train_epoch(model, optimizer)
    274 tgt_input = tgt[:-1, :]
    276 src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
--> 278 logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
    280 optimizer.zero_grad()
    282 tgt_out = tgt[1:, :]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
    169     return self.module(*inputs[0], **kwargs[0])
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
    172 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
     87     output = results[i]
     88     if isinstance(output, ExceptionWrapper):
---> 89         output.reraise()
     90     outputs.append(output)
     91 return outputs

File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1768/3569133796.py", line 161, in forward
    outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 145, in forward
    memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 315, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 591, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 599, in _sa_block
    x = self.self_attn(x, x, x,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5251, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([21, 41]), but should be (21, 21).
```

Versions

PyTorch 2.0.1, CUDA 11.7

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

aminaqi commented 6 months ago

I'm currently using this for distributed training, and I get an error on the attention masks:


```python
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

if torch.cuda.device_count() > 1:
    transformer = nn.DataParallel(transformer)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)
```
awgu commented 6 months ago

Could you try using DistributedDataParallel from torch.nn.parallel.distributed instead of nn.DataParallel?

https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
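
For reference, here is a minimal sketch of what that swap could look like for the snippet above. It is illustrative only: it assumes the `Seq2SeqTransformer`, hyperparameters, `train_data`, and `collate_fn` from the original code, plus a single-node launch via `torchrun` (see the launch example at the end of the thread). Unlike `DataParallel`, which scatters positional arguments along dim 0 (the sequence dimension in this seq-first setup, which appears to be what produces the `(21, 41)` mask in the traceback), DDP runs one full replica per process, so the masks keep the shapes the model expects.

```python
# Sketch only: assumes Seq2SeqTransformer, the hyperparameters, train_data and
# collate_fn defined in the issue, plus a torchrun launch (torchrun sets RANK,
# LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every process).
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")        # reads the env vars set by torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                           NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
model = model.to(device)
model = DDP(model, device_ids=[local_rank])    # one full replica per process/GPU

# Each rank reads its own shard of the data instead of having one batch split
# across GPUs, so src/tgt (and hence the masks) keep their full sequence length.
train_sampler = DistributedSampler(train_data)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                          sampler=train_sampler, collate_fn=collate_fn)
```

In the training loop, `train_sampler.set_epoch(epoch)` should be called once per epoch so the shuffling differs between epochs, and note that `BATCH_SIZE` is now per process, so the effective global batch size is `BATCH_SIZE * world_size`.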

aminaqi commented 5 months ago

> Could you try using DistributedDataParallel from torch.nn.parallel.distributed instead of nn.DataParallel?
>
> https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

What should I set for master_addr and master_port?

awgu commented 5 months ago

@aminaqi Could you try taking a look at https://pytorch.org/tutorials/intermediate/ddp_tutorial.html?

If you are running on a single host, then using torchrun might be the simplest option.
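
For example (an illustration, with `train.py` standing in for the script above), a single-node launch could be `torchrun --standalone --nproc_per_node=<num_gpus> train.py`. torchrun then exports MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, and WORLD_SIZE for each worker process, so nothing has to be chosen by hand. If you launch the processes yourself instead, a sketch along these lines works on one machine (illustrative values; `run_training` is a hypothetical per-rank function that builds the model and wraps it in DDP as above):

```python
# Manual single-node alternative to torchrun (sketch; run_training is a
# hypothetical per-rank function that builds the model and wraps it in DDP).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_training(rank: int, world_size: int):
    # On a single machine the master is just localhost; the port only needs
    # to be a free TCP port that all ranks agree on.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DDP(model, device_ids=[rank]), train ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run_training, args=(world_size,), nprocs=world_size)
```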