Closed zhao1402072392 closed 4 years ago
Can you link the code you are using?
WMT14 is considerably larger than Multi30k. It will have a lot more examples, which means each epoch takes longer, and has a larger vocabulary, which means the number of parameters in your model will be considerably greater, again causing each epoch to take longer.
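To make the vocabulary point concrete, here is a rough sketch that counts only the parameters whose size depends on the vocabulary. The function name, the layer sizes (256-dimensional embeddings, 512 hidden units) and the two vocabulary sizes are assumptions chosen for illustration; they are not taken from the notebook or from either dataset.

```python
def vocab_dependent_params(src_vocab, trg_vocab, emb_dim=256, hid_dim=512):
    """Count parameters that scale with vocabulary size.

    Assumes one embedding table per side and a linear output layer
    (with bias) that maps hid_dim to the target vocabulary.
    """
    src_embedding = src_vocab * emb_dim
    trg_embedding = trg_vocab * emb_dim
    output_layer = hid_dim * trg_vocab + trg_vocab  # weights + bias
    return src_embedding + trg_embedding + output_layer


# Hypothetical vocabulary sizes, for comparison only.
small = vocab_dependent_params(src_vocab=8_000, trg_vocab=6_000)
large = vocab_dependent_params(src_vocab=50_000, trg_vocab=50_000)
print(f'small vocab: {small:,} parameters')  # small vocab: 6,662,000 parameters
print(f'large vocab: {large:,} parameters')  # large vocab: 51,250,000 parameters
```

The recurrent layers are left out because their size does not change with the vocabulary.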
I just used your lesson "1 - Sequence to Sequence Learning with Neural Networks.ipynb" and changed the dataset: WMT'14 as the training set, newstest2013.en / newstest2013.de as the validation set, and newstest2014.en / newstest2014.de as the test set. Dataset link: https://nlp.stanford.edu/projects/nmt/

print("NOW, preparing the TRAIN dataset !!")
train = TranslationDataset(path='../data/train/', exts=('train.en', 'train.de'), fields=(EN, DE))
print("The Train dataset is ready,now preparing the VAL dataset!!")
val = TranslationDataset(path='../data/val/', exts=('newstest2013.en', 'newstest2013.de'), fields=(EN, DE))
print("The VAL dataset is ready,now preparing the TEST dataset!!")
test = TranslationDataset(path='../data/test/', exts=('newstest2014.en', 'newstest2014.de'), fields=(EN, DE))
print("The TEST dataset is ready,now ")
# With fields=(EN, DE), train.src is English and train.trg is German
DE.build_vocab(train.trg, max_size=10000)
EN.build_vocab(train.src, max_size=10000)
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train, val, test), batch_size=batch_size, repeat=False, device=0)
I had a similar out-of-memory issue when I tried a bigger dataset (500K+ long sentence pairs) for another task.
Even with a batch size of 1 and a limited vocab, it runs, but before finishing 1 epoch (after some hours) it OOMs (on 1 TITAN X GPU with ~12GB of memory). I'm using the code from the 4th RNN tutorial (with GRU changed to LSTM).
I’ll have a look at this and see if there’s a bug in the code.
I did the same experiment with two datasets of different sizes.
With the small one (17K sentence pairs), training for an epoch was quite fast, less than a minute (although I hit OOM if I chose bigger hyperparameters with the same batch size).
With the bigger dataset (540K sentence pairs) I hit OOM using small hyperparameters with a batch size of 1 and a fixed vocab size (source 10-20K, target 1K).
I don't know if there is a bug, or if we have reached the limit of this implementation in terms of memory and computational efficiency.
So I've had a look at the WMT14 dataset - I believe the issue is that there are some very long sentences in there. There is an example with over 450 tokens in the training set. The Multi30k dataset has an average sentence length of around 15 tokens. When TorchText tries to load a batch with that example it will create a tensor of [450, batch size] on the GPU, which is probably causing the OOM issues.
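As a rough illustration of why a single long example matters: every sentence in a batch is padded to the length of the longest one, so one outlier inflates the whole tensor. The batch size of 128 and the helper name below are assumptions for the sketch; the 15 and 450 token figures are the ones mentioned above.

```python
def padded_batch_elements(lengths):
    """Elements in a [max_len, batch_size] tensor padded to the longest example."""
    return max(lengths) * len(lengths)


typical = [15] * 128                # 128 sentences of 15 tokens each
with_outlier = [15] * 127 + [450]   # same batch with one 450-token sentence

print(padded_batch_elements(typical))       # 1920
print(padded_batch_elements(with_outlier))  # 57600
```

The same factor of 30 applies to every per-token activation computed for that batch, and the decoder also has to run for as many steps as the longest target sentence.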
One interesting thing is that the WMT14 dataset already has a vocabulary, see wmt14/vocab.bpe.32000, thus we can directly load that and avoid creating a vocabulary. Note: this is a shared vocabulary, i.e. it has both English and German in it. However, if you're using the BPE data then you should be using the BPE vocabulary anyway!
We can also tokenize the WMT14 dataset beforehand so it does not have to be tokenized each time the program runs. This will save a lot of time preprocessing each time you want to run the translation code.
I created a tokenize_wmt14.py file:
import os
import json
from tqdm import tqdm

def file_iterator(path):
    with open(path, 'r') as f:
        for sentence in f:
            yield sentence

def tokenize_data(en_read_path, de_read_path, write_path, max_length):
    if os.path.exists(write_path):
        print(f'{write_path} already exists!')
        return 0
    en_file_iterator = file_iterator(en_read_path)
    de_file_iterator = file_iterator(de_read_path)
    with open(write_path, 'w+') as f:
        for en_sent, de_sent in tqdm(zip(en_file_iterator, de_file_iterator)):
            en_tokens = en_sent.split()[:max_length]
            de_tokens = de_sent.split()[:max_length]
            example = {'de': de_tokens, 'en': en_tokens}
            json.dump(example, f)
            f.write('\n')

max_length = 25

tokenize_data('.data/wmt14/train.tok.clean.bpe.32000.en',
              '.data/wmt14/train.tok.clean.bpe.32000.de',
              '.data/wmt14/train.tok.clean.bpe.32000.jsonl',
              max_length)
tokenize_data('.data/wmt14/newstest2013.tok.bpe.32000.en',
              '.data/wmt14/newstest2013.tok.bpe.32000.de',
              '.data/wmt14/newstest2013.tok.bpe.32000.jsonl',
              max_length)
tokenize_data('.data/wmt14/newstest2014.tok.bpe.32000.en',
              '.data/wmt14/newstest2014.tok.bpe.32000.de',
              '.data/wmt14/newstest2014.tok.bpe.32000.jsonl',
              max_length)
You only need to run this once and it will tokenize the examples (simply splits them on whitespace as they have already been tokenized), cut them to max_length and save them in a .jsonl format which we can open with TorchText.
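If you want to check the output before training, a small sketch like the following reads back the first few lines. The helper name peek_jsonl is hypothetical; it uses only the standard library and assumes the file has one JSON object per line with 'de' and 'en' keys, as written by the script above.

```python
import json
import os
from itertools import islice


def peek_jsonl(path, n=3):
    """Return the first n examples of a JSON-lines file as dicts."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in islice(f, n)]


path = '.data/wmt14/train.tok.clean.bpe.32000.jsonl'
if os.path.exists(path):  # only runs once the tokenized file has been created
    for example in peek_jsonl(path):
        print(len(example['de']), len(example['en']), example['en'][:5])
```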
Then, when doing the actual translation we can make a tutorial_4_wmt14.py file with the following:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.data import Field, TabularDataset, BucketIterator
from tqdm import tqdm
import random
import math
import time
import functools
SEED = 1234
BATCH_SIZE = 256
N_EPOCHS = 10
CLIP = 1
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
def load_vocab(path, pad_token, unk_token, sos_token, eos_token):
    vocab = {pad_token: 0, unk_token: 1, sos_token: 2, eos_token: 3}
    with open(path, 'r') as f:
        for tok in f:
            vocab[tok.strip()] = len(vocab)
    return vocab

def numericalize(vocab, unk_token, tokens):
    idxs = [vocab.get(t, vocab[unk_token]) for t in tokens]
    return idxs

print('Loading vocab...')

vocab = load_vocab('.data/wmt14/vocab.bpe.32000',
                   PAD_TOKEN,
                   UNK_TOKEN,
                   SOS_TOKEN,
                   EOS_TOKEN)

print(f'Vocab has {len(vocab)} items')

numericalizer = functools.partial(numericalize, vocab, UNK_TOKEN)

EN = Field(use_vocab = False,
           preprocessing = numericalizer,
           init_token = vocab[SOS_TOKEN],
           eos_token = vocab[EOS_TOKEN],
           pad_token = vocab[PAD_TOKEN],
           unk_token = vocab[UNK_TOKEN])

DE = Field(use_vocab = False,
           preprocessing = numericalizer,
           init_token = vocab[SOS_TOKEN],
           eos_token = vocab[EOS_TOKEN],
           pad_token = vocab[PAD_TOKEN],
           unk_token = vocab[UNK_TOKEN],
           include_lengths = True)

fields = {'de': ('de', DE), 'en': ('en', EN)}

train_data, valid_data, test_data = TabularDataset.splits(
    path = '.data/wmt14',
    train = 'train.tok.clean.bpe.32000.jsonl',
    validation = 'newstest2013.tok.bpe.32000.jsonl',
    test = 'newstest2014.tok.bpe.32000.jsonl',
    format = 'json',
    fields = fields
)

print(f'{len(train_data)} training examples')
print(f'{len(valid_data)} validation examples')
print(f'{len(test_data)} test examples')
print(f'Example: {vars(train_data[0])}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    sort_key = lambda x : len(x.de),
    device = device)
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):
        embedded = self.dropout(self.embedding(src))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len)
        packed_outputs, hidden = self.rnn(packed_embedded)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        return outputs, hidden
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Parameter(torch.rand(dec_hid_dim))

    def forward(self, hidden, encoder_outputs, mask):
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
        energy = energy.permute(0, 2, 1)
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        attention = torch.bmm(v, energy).squeeze(1)
        attention = attention.masked_fill(mask == 0, -1e10)
        return F.softmax(attention, dim = 1)
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        a = self.attention(hidden, encoder_outputs, mask)
        a = a.unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        output = self.out(torch.cat((output, weighted, embedded), dim = 1))
        return output, hidden.squeeze(0), a.squeeze(1)
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def create_mask(self, src):
        mask = (src != self.pad_idx).permute(1, 0)
        return mask

    def forward(self, src, src_len, trg, teacher_forcing_ratio = 0.5):
        batch_size = src.shape[1]
        max_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        attentions = torch.zeros(max_len, batch_size, src.shape[0]).to(self.device)
        encoder_outputs, hidden = self.encoder(src, src_len)
        input = trg[0,:]
        mask = self.create_mask(src)
        for t in range(1, max_len):
            output, hidden, attention = self.decoder(input, hidden, encoder_outputs, mask)
            outputs[t] = output
            attentions[t] = attention
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1
        return outputs, attentions
INPUT_DIM = len(vocab)
OUTPUT_DIM = len(vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
PAD_IDX = vocab[PAD_TOKEN]
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, PAD_IDX, device).to(device)
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for batch in tqdm(iterator):
        src, src_len = batch.de
        trg = batch.en
        optimizer.zero_grad()
        output, attention = model(src, src_len, trg)
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch in tqdm(iterator):
            src, src_len = batch.de
            trg = batch.en
            output, attention = model(src, src_len, trg, 0) #turn off teacher forcing
            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut4-model-wmt14.pt')
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')

model.load_state_dict(torch.load('tut4-model-wmt14.pt'))
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
This runs the notebook/tutorial 4 code on the WMT14 dataset, using the vocabulary already provided with it. With a batch size of 256 it uses about 8.5GB of GPU memory (about 2/3 of my 1080Ti) and is estimated to take four and a half hours per epoch, which I think is reasonable for what I believe is about 4.5 million examples.
Obviously, if you want more than 25 tokens per sentence this will use more GPU memory and take longer - but if 25 is fine then you can get away with increasing the batch size to reduce time per epoch.
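To get a feel for this trade-off, the sketch below estimates the size of just the `outputs` tensor that `Seq2Seq.forward` allocates, which has shape [max_len, batch_size, trg_vocab_size]. The helper name and the round vocabulary size of 32,000 are assumptions (the script prints the real `len(vocab)`), and the estimate ignores the <sos>/<eos> tokens, gradients and every other activation, so treat it as a lower bound for one tensor rather than a prediction of total memory use.

```python
def outputs_tensor_mib(max_len, batch_size, vocab_size, bytes_per_float=4):
    """Size in MiB of a float32 tensor of shape [max_len, batch_size, vocab_size]."""
    return max_len * batch_size * vocab_size * bytes_per_float / 2**20


VOCAB_SIZE = 32_000  # assumed round figure, not the measured vocabulary size

for max_len in (25, 50, 100):
    size = outputs_tensor_mib(max_len, 256, VOCAB_SIZE)
    print(f'max_len={max_len:>3}: {size:,.2f} MiB')
```

Because the size is a plain product, halving max_len leaves room to double the batch size for the same tensor size, and the reverse.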
@gorka96 I'm not sure what the issue is with your dataset, but I would suggest looking for incredibly long sentences and either trimming them or getting rid of them.
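One way to look for such outliers, sketched under the assumption that the data is available as (source tokens, target tokens) pairs. The function names, the toy data and the cutoff of 100 tokens are all hypothetical.

```python
def length_report(pairs, max_length):
    """Return (longest side seen, number of pairs with a side over max_length)."""
    longest = 0
    too_long = 0
    for src, trg in pairs:
        n = max(len(src), len(trg))
        longest = max(longest, n)
        if n > max_length:
            too_long += 1
    return longest, too_long


def drop_long_pairs(pairs, max_length):
    """Keep only pairs where both sides have at most max_length tokens."""
    return [(s, t) for s, t in pairs
            if len(s) <= max_length and len(t) <= max_length]


# Toy data: one normal pair and one stand-in for a very long example.
pairs = [
    ('a short source'.split(), 'a short target'.split()),
    (['tok'] * 500, ['tok'] * 20),
]
print(length_report(pairs, max_length=100))         # (500, 1)
print(len(drop_long_pairs(pairs, max_length=100)))  # 1
```

Dropping pairs keeps the remaining sentences intact, whereas trimming (as tokenize_wmt14.py does with max_length) keeps every example but cuts the long ones off mid-sentence.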
Hi @bentrevett, my issue is the same one you pointed out. In my small dataset, the longest sentence was 122 words, which is already longer than I expected (I'm using 2 concatenated sentences as src). In the big dataset (built in an unsupervised way), the longest was 1455 words. I didn't expect sentences of that length in the dataset. Thank you for the suggestion.
Thank you so much, this code really helped me a lot. It's very kind of you ^_^ Thanks.
But does '.data/wmt14/vocab.bpe.32000' contain the vocab for both de and en? Could you give me the link to the vocab, please? I can't find it. Thanks ^_^
@zhao1402072392 I got all the files from the download link from the TorchText repo, which is https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8
There should be the vocab.bpe.32000 file in there once extracted. And yes, the vocab was created from both German and English. I've not really come across a shared vocabulary before, but it looks like that's how they did it.
ok,thanks again^_^
I tried to train on the WMT'14 dataset on a 2080Ti. It only runs with a batch size of 2, otherwise it OOMs, and training is too slow, about 20 hours per epoch. I don't know how to deal with this problem. Could you give me some advice? Thanks.