shentianxiao / language-style-transfer

OOM when using my own dataset #29

Open · shizhediao opened this issue 4 years ago

shizhediao commented 4 years ago

Hi, thanks for your great work. I'm confused: when training on my own dataset (86,217 lines, 7,695,997 words, ~43 MB), I run into OOM very easily, even with batch_size=10. I think there must be some problem. Could you help me figure it out? Thanks!

Wchenguang commented 4 years ago

Fix get_batch() and get_batches() in utils.py as below. Every sentence in a batch is padded to the length of the longest one, so a single very long line makes all of the batch tensors that wide and can exhaust memory; capping the padded length at max_seq_length bounds it:

def get_batch(x, y, word2id, max_seq_length, noisy=False, min_len=5):
    ...
    rev_x, go_x, x_eos, weights = [], [], [], []
    max_len = max([len(sent) for sent in x])
    # cap the padded length so one very long sentence cannot blow up the whole batch
    if max_len > max_seq_length:
        max_len = max_seq_length
    max_len = max(max_len, min_len)
    for sent in x:
        # map words to ids, truncating each sentence to the capped length
        sent_id = [word2id[w] if w in word2id else unk for w in sent][:max_len]
        l = len(sent_id)
        padding = [pad] * (max_len - l)
        _sent_id = noise(sent_id, unk) if noisy else sent_id
        rev_x.append(padding + _sent_id[::-1])
        go_x.append([go] + sent_id + padding)
        x_eos.append(sent_id + [eos] + padding)
        weights.append([1.0] * (l+1) + [0.0] * (max_len-l))
    ...
def get_batches(x0, x1, word2id, batch_size, max_seq_length, noisy=False):
    # note the extra max_seq_length argument, passed through to get_batch() below
    ...
    while s < n:
        t = min(s + batch_size, n)
        batches.append(get_batch(x0[s:t] + x1[s:t],
            [0]*(t-s) + [1]*(t-s), word2id, max_seq_length, noisy))
        s = t

    return batches, order0, order1
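
For what it's worth, here is a tiny standalone illustration of why the cap helps (the cap value and the toy batch below are made up for this example, not taken from the repo). Without the cap, one outlier line drags the padded width of the whole batch up to its own length; with the cap, the width is bounded:

def padded_width(sents, cap=None):
    # length every sentence in the batch ends up padded/truncated to
    max_len = max(len(s) for s in sents)
    if cap is not None and max_len > cap:
        max_len = cap  # the one-line cap the patch above adds
    return max_len

batch = [['the', 'cat', 'sat'], ['word'] * 500]  # one pathological 500-word line

print(padded_width(batch))           # 500 -> every tensor in the batch is 500 steps wide
print(padded_width(batch, cap=20))   # 20  -> bounded no matter how long the outlier is

If you apply the patch, keep in mind that get_batches() now takes an extra max_seq_length argument, so its call sites in the training/transfer script need to pass a value for it as well.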