bentrevett / pytorch-seq2seq

Tutorials on implementing a few sequence-to-sequence (seq2seq) models with PyTorch and TorchText.
MIT License

Why is training on the WMT'14 dataset so slow? #52

Closed zhao1402072392 closed 4 years ago

zhao1402072392 commented 4 years ago

I am trying to train on the WMT'14 dataset with a 2080 Ti. It only runs with a batch size of 2; anything larger gives an OOM error. Training is also very slow, about 20 hours per epoch. I don't know how to deal with this problem. Could you give me some advice? Thanks.

bentrevett commented 4 years ago

Can you link the code you are using?

WMT14 is considerably larger than Multi30k. It will have a lot more examples, which means each epoch takes longer, and has a larger vocabulary, which means the number of parameters in your model will be considerably greater, again causing each epoch to take longer.
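
As a rough illustration of the vocabulary point, here is a minimal sketch that counts only the parameters whose size scales with the vocabulary: the two embedding tables and the final projection layer. The function name, the vocabulary sizes and the layer widths are made up for the example and are not taken from either dataset.

```python
def vocab_dependent_params(vocab_size, emb_dim, out_in_features):
    """Count the parameters that grow with the vocabulary.

    Assumes one source embedding, one target embedding, and a final
    linear layer (with bias) that projects onto the target vocabulary.
    """
    src_embedding = vocab_size * emb_dim
    trg_embedding = vocab_size * emb_dim
    output_layer = out_in_features * vocab_size + vocab_size  # weights + bias
    return src_embedding + trg_embedding + output_layer

# Hypothetical sizes: 256-dim embeddings, 512 features feeding the output layer.
small = vocab_dependent_params(8_000, 256, 512)
large = vocab_dependent_params(32_000, 256, 512)
print(f'{small:,} vs {large:,} vocabulary-dependent parameters')
```

Under these assumptions the count grows linearly with the vocabulary, so four times as many tokens means four times as many of these parameters, on top of the extra examples per epoch.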

zhao1402072392 commented 4 years ago

I am just using your lesson "1 - Sequence to Sequence Learning with Neural Networks.ipynb", changed to use WMT'14 as the training dataset, newstest2013.en and newstest2013.de as the validation dataset, and newstest2014.en and newstest2014.de as the test dataset. Dataset link: https://nlp.stanford.edu/projects/nmt/

print("NOW, preparing the TRAIN dataset !!")
train = TranslationDataset(path='../data/train/', exts=('train.en', 'train.de'), fields=(EN, DE))
print("The Train dataset is ready,now preparing the VAL dataset!!")
val = TranslationDataset(path='../data/val/', exts=('newstest2013.en', 'newstest2013.de'), fields=(EN, DE))
print("The VAL dataset is ready,now preparing the TEST dataset!!")
test = TranslationDataset(path='../data/test/', exts=('newstest2014.en', 'newstest2014.de'), fields=(EN, DE))
print("The TEST dataset is ready,now ")
DE.build_vocab(train.src, max_size=10000)
EN.build_vocab(train.src, max_size=10000)
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train, val, test), batch_size=batch_size, repeat=False, device=0)


GorkaUrbizu commented 4 years ago

I had a similar out-of-memory issue when I tried a bigger dataset (500K+ long sentence pairs) for another task.

Even with a batch size of 1 and a limited vocab, it runs for some hours but hits OOM before finishing the first epoch (on one TITAN X GPU with ~12GB of memory). I'm using the code from the 4th RNN tutorial (with GRU changed to LSTM).

bentrevett commented 4 years ago

I’ll have a look at this and see if there’s a bug in the code.

GorkaUrbizu commented 4 years ago

I ran the same experiment with two datasets of different sizes.

With the small one (17K sentence pairs), training for an epoch was quite fast, less than a minute (although I ran out of memory if I chose bigger hyperparameters with the same batch size).

With the bigger dataset (540K sentence pairs), I ran out of memory even with small hyperparameters, a batch size of 1, and a fixed vocab size (source 10-20K, target 1K).

I don't know whether there is a bug, or whether we have reached the limit of this implementation in terms of memory and computational efficiency.

bentrevett commented 4 years ago

So I've had a look at the WMT14 dataset - I believe the issue is that there are some very long sentences in there. There is an example with over 450 tokens in the training set. The Multi30k dataset has an average sentence length of around 15 tokens. When TorchText tries to load a batch with that example it will create a tensor of [450, batch size] on the GPU, which is probably causing the OOM issues.
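
To make the effect of one long example concrete, here is a small sketch of the arithmetic. Every example in a batch is padded to the longest one, and the Seq2Seq models in these tutorials also allocate an output tensor of shape [trg len, batch size, vocab size]. The helper names, the batch size of 128 and the choice of a 450-token target are assumptions for illustration only.

```python
def padded_cells(lengths):
    """Token slots in a batch once every example is padded to the longest one."""
    return max(lengths) * len(lengths)

def output_tensor_bytes(trg_len, batch_size, vocab_size, bytes_per_float=4):
    """Size in bytes of a float32 tensor of shape [trg_len, batch_size, vocab_size]."""
    return trg_len * batch_size * vocab_size * bytes_per_float

# Hypothetical batches of 128 examples: all short, versus one 450-token outlier.
typical = [15] * 128
with_outlier = [15] * 127 + [450]

print(padded_cells(typical))       # 1920
print(padded_cells(with_outlier))  # 57600

# Output tensor for a 32,000-token vocabulary, before any gradients are stored.
for trg_len in (15, 450):
    gib = output_tensor_bytes(trg_len, 128, 32_000) / 2**30
    print(f'trg_len={trg_len}: {gib:.2f} GiB')
```

A single outlier makes the whole batch 30 times larger, which is why trimming or dropping a handful of very long sentences can matter more than lowering the batch size.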

One interesting thing is that the WMT14 dataset already has a vocabulary, see wmt14/vocab.bpe.32000, thus we can directly load that and avoid creating a vocabulary. Note: this is a shared vocabulary, i.e. it has both English and German in it. However, if you're using the BPE data then you should be using the BPE vocabulary anyway!

We can also tokenize the WMT14 dataset beforehand so it does not have to be tokenized each time the program runs. This will save a lot of time preprocessing each time you want to run the translation code.

I created a tokenize_wmt14.py file:

import os
import json
from tqdm import tqdm

def file_iterator(path):
    # yield one line (one sentence) at a time so the file never has to fit in memory
    with open(path, 'r', encoding='utf-8') as f:
        for sentence in f:
            yield sentence

def tokenize_data(en_read_path, de_read_path, write_path, max_length):

    if os.path.exists(write_path):
        print(f'{write_path} already exists!')
        return 0

    en_file_iterator = file_iterator(en_read_path)
    de_file_iterator = file_iterator(de_read_path)

    with open(write_path, 'w+') as f:
        for en_sent, de_sent in tqdm(zip(en_file_iterator, de_file_iterator)):

            en_tokens = en_sent.split()[:max_length]
            de_tokens = de_sent.split()[:max_length]

            example = {'de': de_tokens, 'en': en_tokens}
            json.dump(example, f)
            f.write('\n')

max_length = 25

tokenize_data('.data/wmt14/train.tok.clean.bpe.32000.en',
              '.data/wmt14/train.tok.clean.bpe.32000.de',
              '.data/wmt14/train.tok.clean.bpe.32000.jsonl',
              max_length)

tokenize_data('.data/wmt14/newstest2013.tok.bpe.32000.en',
              '.data/wmt14/newstest2013.tok.bpe.32000.de',
              '.data/wmt14/newstest2013.tok.bpe.32000.jsonl',
              max_length)

tokenize_data('.data/wmt14/newstest2014.tok.bpe.32000.en',
              '.data/wmt14/newstest2014.tok.bpe.32000.de',
              '.data/wmt14/newstest2014.tok.bpe.32000.jsonl',
              max_length)

You only need to run this once and it will tokenize the examples (simply splits them on whitespace as they have already been tokenized), cut them to max_length and save them in a .jsonl format which we can open with TorchText.
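
As a quick sanity check before training, the .jsonl output can be read back with the standard library alone. This is only a sketch: `read_jsonl` is a hypothetical helper, and the demo writes a throwaway file with made-up sentences rather than touching the real WMT14 files.

```python
import json
import os
import tempfile

def read_jsonl(path, limit=None):
    """Yield one dict per line of a JSON Lines file, stopping after `limit` records."""
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            yield json.loads(line)

# Demo file written the same way as in tokenize_data: one JSON object per line.
demo_path = os.path.join(tempfile.mkdtemp(), 'demo.jsonl')
with open(demo_path, 'w', encoding='utf-8') as f:
    for example in [{'de': ['ein', 'Haus'], 'en': ['a', 'house']},
                    {'de': ['zwei', 'Hunde'], 'en': ['two', 'dogs']}]:
        json.dump(example, f)
        f.write('\n')

records = list(read_jsonl(demo_path, limit=5))
for example in records:
    print(len(example['de']), len(example['en']), example)
```

Pointing `read_jsonl` at one of the real output paths with a small `limit` should show token lists no longer than max_length.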

Then, when doing the actual translation we can make a tutorial_4_wmt14.py file with the following:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

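# Field, TabularDataset and BucketIterator are the legacy TorchText API: from torchtext 0.9
# they live in torchtext.legacy.data, and they were removed in 0.12
from torchtext.data import Field, TabularDataset, BucketIterator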

from tqdm import tqdm

import random
import math
import time
import functools

SEED = 1234
BATCH_SIZE = 256
N_EPOCHS = 10
CLIP = 1
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'

random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

def load_vocab(path, pad_token, unk_token, sos_token, eos_token):
    # special tokens take the first four indices; the BPE tokens follow in file order
    vocab = {pad_token: 0, unk_token: 1, sos_token: 2, eos_token: 3}
    with open(path, 'r', encoding='utf-8') as f:
        for tok in f:
            vocab[tok.strip()] = len(vocab)
    return vocab

def numericalize(vocab, unk_token, tokens):
    idxs = [vocab.get(t, vocab[unk_token]) for t in tokens]
    return idxs

print('Loading vocab...')
vocab = load_vocab('.data/wmt14/vocab.bpe.32000', 
                   PAD_TOKEN, 
                   UNK_TOKEN, 
                   SOS_TOKEN, 
                   EOS_TOKEN)

print(f'Vocab has {len(vocab)} items')

numericalizer = functools.partial(numericalize, vocab, UNK_TOKEN)

EN = Field(use_vocab = False,
           preprocessing = numericalizer,
           init_token = vocab[SOS_TOKEN],
           eos_token = vocab[EOS_TOKEN],
           pad_token = vocab[PAD_TOKEN],
           unk_token = vocab[UNK_TOKEN])

DE = Field(use_vocab = False,
           preprocessing = numericalizer,
           init_token = vocab[SOS_TOKEN],
           eos_token = vocab[EOS_TOKEN],
           pad_token = vocab[PAD_TOKEN],
           unk_token = vocab[UNK_TOKEN],
           include_lengths = True)

fields = {'de': ('de', DE), 'en': ('en', EN)}

train_data, valid_data, test_data = TabularDataset.splits(
                                        path = '.data/wmt14',
                                        train = 'train.tok.clean.bpe.32000.jsonl',
                                        validation = 'newstest2013.tok.bpe.32000.jsonl',
                                        test = 'newstest2014.tok.bpe.32000.jsonl',
                                        format = 'json',
                                        fields = fields
)

print(f'{len(train_data)} training examples')
print(f'{len(valid_data)} validation examples')
print(f'{len(test_data)} test examples')

print(f'Example: {vars(train_data[0])}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     sort_within_batch = True,
     sort_key = lambda x : len(x.de),
     device = device)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):

        embedded = self.dropout(self.embedding(src))
        # newer PyTorch versions require the lengths to be on the CPU
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.cpu())
        packed_outputs, hidden = self.rnn(packed_embedded)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs) 
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))

        return outputs, hidden

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Parameter(torch.rand(dec_hid_dim))

    def forward(self, hidden, encoder_outputs, mask):

        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]

        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
        energy = energy.permute(0, 2, 1)
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        attention = torch.bmm(v, energy).squeeze(1)
        attention = attention.masked_fill(mask == 0, -1e10)

        return F.softmax(attention, dim = 1)

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):

        input = input.unsqueeze(0)

        embedded = self.dropout(self.embedding(input))
        a = self.attention(hidden, encoder_outputs, mask)
        a = a.unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        output = self.out(torch.cat((output, weighted, embedded), dim = 1))

        return output, hidden.squeeze(0), a.squeeze(1)

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def create_mask(self, src):
        mask = (src != self.pad_idx).permute(1, 0)
        return mask

    def forward(self, src, src_len, trg, teacher_forcing_ratio = 0.5):

        batch_size = src.shape[1]
        max_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        attentions = torch.zeros(max_len, batch_size, src.shape[0]).to(self.device)
        encoder_outputs, hidden = self.encoder(src, src_len)
        input = trg[0,:]
        mask = self.create_mask(src)

        for t in range(1, max_len):

            output, hidden, attention = self.decoder(input, hidden, encoder_outputs, mask)
            outputs[t] = output
            attentions[t] = attention
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1) 
            input = trg[t] if teacher_force else top1

        return outputs, attentions

INPUT_DIM = len(vocab)
OUTPUT_DIM = len(vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
PAD_IDX = vocab[PAD_TOKEN]

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, PAD_IDX, device).to(device)

def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters())

criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)

def train(model, iterator, optimizer, criterion, clip):

    model.train()

    epoch_loss = 0

    for batch in tqdm(iterator):

        src, src_len = batch.de
        trg = batch.en

        optimizer.zero_grad()
        output, attention = model(src, src_len, trg)
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        for batch in tqdm(iterator):

            src, src_len = batch.de
            trg = batch.en

            output, attention = model(src, src_len, trg, 0) #turn off teacher forcing
            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut4-model-wmt14.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

model.load_state_dict(torch.load('tut4-model-wmt14.pt'))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

This runs the notebook/tutorial 4 code on the WMT14 dataset, using the vocabulary already provided with it. It has a batch size of 256, uses about 8.5GB of GPU memory (so about 2/3 of my 1080Ti) and is estimated to take four and a half hours per epoch, which I think is reasonable for what I believe is around 4.5 million examples.
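
For a sense of what those figures imply, here is the back-of-the-envelope arithmetic as a short sketch. It uses only the approximate numbers quoted above (about 4.5 million examples, batch size 256, about four and a half hours per epoch), so the results are estimates rather than measurements.

```python
import math

n_examples = 4_500_000   # approximate training set size quoted above
batch_size = 256
hours_per_epoch = 4.5    # estimated, not measured over a full epoch

seconds_per_epoch = hours_per_epoch * 3600
batches_per_epoch = math.ceil(n_examples / batch_size)
batches_per_second = batches_per_epoch / seconds_per_epoch
examples_per_second = n_examples / seconds_per_epoch

print(f'{batches_per_epoch:,} batches per epoch')
print(f'{batches_per_second:.2f} batches/s, {examples_per_second:.0f} examples/s')
```

That works out to roughly one batch of 256 per second, which is a handy number to compare against the rate tqdm reports during training.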

Obviously, if you want more than 25 tokens per sentence this will use more GPU memory and take longer - but if 25 is fine then you can get away with increasing the batch size to reduce time per epoch.

@gorka96 I'm not sure what the issue is with your dataset, but I would suggest looking for incredibly long sentences and either trimming them or getting rid of them.
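
The tokenize_wmt14.py script above trims long sentences. If getting rid of them is preferable, a filter along these lines is one option. This is a sketch under assumptions: `filter_pairs` is a hypothetical helper, the sample pairs are made up, and dropping empty sides is an extra choice not discussed in this thread.

```python
def filter_pairs(pairs, max_length):
    """Yield only (src_tokens, trg_tokens) pairs where both sides are non-empty
    and no longer than max_length tokens."""
    for src_tokens, trg_tokens in pairs:
        if 0 < len(src_tokens) <= max_length and 0 < len(trg_tokens) <= max_length:
            yield src_tokens, trg_tokens

pairs = [
    ('ein kleines Haus'.split(), 'a small house'.split()),
    (['wort'] * 1455, ['word'] * 40),   # far too long on both sides
    ([], ['orphan']),                   # empty source side
]

kept = list(filter_pairs(pairs, max_length=25))
print(f'kept {len(kept)} of {len(pairs)} pairs')
```

Dropping keeps every remaining pair complete, whereas trimming keeps more data but can leave a truncated source aligned with a truncated target that no longer match.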

GorkaUrbizu commented 4 years ago

Hi @bentrevett, my issue is the same one you pointed out. In my small dataset, the maximum sentence length was 122 words, which is already longer than I expected (I'm using 2 concatenated sentences as src). Moreover, in the big dataset (built in an unsupervised way), the maximum length was 1455 words. I didn't expect sentences of that length in the dataset. Thank you for the suggestion.
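
A quick way to spot this kind of outlier before training is to summarise whitespace-token lengths over the corpus. The sketch below uses a hypothetical `length_summary` helper and a made-up corpus; the nearest-rank percentile is just one reasonable definition.

```python
def length_summary(sentences, percentile=99):
    """Return count, mean, nearest-rank percentile and max of whitespace-token lengths."""
    lengths = sorted(len(s.split()) for s in sentences)
    if not lengths:
        raise ValueError('no sentences given')
    n = len(lengths)
    rank = max(1, (percentile * n + 99) // 100)  # integer ceil(percentile * n / 100)
    return {
        'count': n,
        'mean': sum(lengths) / n,
        'percentile': lengths[rank - 1],
        'max': lengths[-1],
    }

# Made-up corpus: 99 three-word sentences and a single 1455-word outlier.
corpus = ['a b c'] * 99 + [' '.join(['w'] * 1455)]
summary = length_summary(corpus)
print(summary)
```

When the maximum is far above the 99th percentile, as here, a max_length near that percentile removes very few examples while bounding the size of every batch.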

zhao1402072392 commented 4 years ago

Thank you so much, this code really helps me a lot. It's very kind of you ^_^ Thanks.

zhao1402072392 commented 4 years ago

But does this '.data/wmt14/vocab.bpe.32000' contain the vocab for both de and en? Could you give me the link to the vocab, please? I can't find it. Thanks ^_^

bentrevett commented 4 years ago

@zhao1402072392 I got all the files from the download link from the TorchText repo, which is https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8

There should be the vocab.bpe.32000 file in there once extracted. And yes, the vocab was created from both German and English. I've not really come across a shared vocabulary before, but it looks like that's how they did it.

zhao1402072392 commented 4 years ago

OK, thanks again ^_^
