[Open] shizhediao opened this issue 4 years ago
Fix `get_batch()` and `get_batches()` in `utils.py` like this:
```python
def get_batch(x, y, word2id, max_seq_length, noisy=False, min_len=5):
    ...
    rev_x, go_x, x_eos, weights = [], [], [], []
    # cap the batch width at max_seq_length so one long sentence
    # can't blow up the padded batch
    max_len = max([len(sent) for sent in x])
    if max_len > max_seq_length:
        max_len = max_seq_length
    max_len = max(max_len, min_len)
    for sent in x:
        # truncate each sentence to the capped length
        sent_id = [word2id[w] if w in word2id else unk for w in sent][:max_len]
        l = len(sent_id)
        padding = [pad] * (max_len - l)
        _sent_id = noise(sent_id, unk) if noisy else sent_id
        rev_x.append(padding + _sent_id[::-1])
        go_x.append([go] + sent_id + padding)
        x_eos.append(sent_id + [eos] + padding)
        weights.append([1.0] * (l+1) + [0.0] * (max_len-l))
    ...
```
```python
def get_batches(x0, x1, word2id, batch_size, max_seq_length, noisy=False):
    ...
    while s < n:
        t = min(s + batch_size, n)
        batches.append(get_batch(x0[s:t] + x1[s:t],
            [0]*(t-s) + [1]*(t-s), word2id, max_seq_length, noisy))
        s = t
    return batches, order0, order1
```
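For anyone else hitting this: the key change is capping `max_len` at `max_seq_length` before building the padding, so a single very long sentence can no longer inflate the whole padded batch. A minimal, self-contained sketch of that idea (the toy vocabulary, the special-token ids, and the simplified signature are my assumptions for illustration, not the repo's actual constants; `noisy`/`rev_x` are omitted for brevity):

```python
# Hypothetical special-token ids and vocabulary, for illustration only.
pad, go, eos, unk = 0, 1, 2, 3
word2id = {'a': 4, 'b': 5, 'c': 6}

def get_batch(x, word2id, max_seq_length, min_len=5):
    go_x, x_eos, weights = [], [], []
    # Cap the batch width at max_seq_length: this is the fix that
    # prevents one long sentence from blowing up the padded batch.
    max_len = max(len(sent) for sent in x)
    max_len = min(max_len, max_seq_length)
    max_len = max(max_len, min_len)
    for sent in x:
        # Truncate each sentence to the capped width, then pad.
        sent_id = [word2id.get(w, unk) for w in sent][:max_len]
        l = len(sent_id)
        padding = [pad] * (max_len - l)
        go_x.append([go] + sent_id + padding)
        x_eos.append(sent_id + [eos] + padding)
        weights.append([1.0] * (l + 1) + [0.0] * (max_len - l))
    return go_x, x_eos, weights

# A 100-token sentence is truncated, so every row has width
# max_seq_length + 1 (one extra slot for the <go>/<eos> token):
long_sent = ['a'] * 100
go_x, x_eos, weights = get_batch([long_sent, ['b', 'c']], word2id, max_seq_length=10)
print(len(go_x[0]))  # 11
```

Without the `min(max_len, max_seq_length)` line, `max_len` would be 100 here and every row in the batch would be padded to that width, which is exactly how a few long lines in a large corpus cause OOM even at small batch sizes.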
Hi, thanks for your great work. I was really confused: when training on my own dataset (86,217 lines, 7,695,997 words, ~43 MB), I run out of memory easily, even with `batch_size=10`. I think there must be a problem somewhere. Could you help me figure it out? Thanks!