huggingface / transfer-learning-conv-ai

🦄 State-of-the-Art Conversational AI with Transfer Learning
MIT License
1.74k stars 430 forks

Batchwise padding dataset #121

Open mrghofrani opened 2 years ago

mrghofrani commented 2 years ago

Hello, I'm pretty new to PyTorch, so sorry if this question is too simple. Because of memory limits, I can't pad my dataset as a whole. What is the simplest way to move the pad_dataset function into the training process? In other words, how can I pad each batch separately instead of the whole dataset at once? For ease of reference, I've added pad_dataset below. Thanks.

def pad_dataset(dataset, padding=0):
    """ Pad the dataset. This could be optimized by defining a Dataset class and padding at the batch level, but this is simpler. """
    max_l = max(len(x) for x in dataset["input_ids"])
    for name in PADDED_INPUTS:
        dataset[name] = [x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name]]
    return dataset
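
One possible direction, as the docstring above hints, is to apply the same padding logic to one batch at a time. The following is only a minimal sketch, not the repository's method: `pad_batch` is a hypothetical helper, the contents of `PADDED_INPUTS` are assumed here, and each example is assumed to be a flat dict of integer lists (the real training script also carries a candidates dimension that this sketch ignores).

```python
# Assumed list of fields to pad; adjust to match your own PADDED_INPUTS.
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]


def pad_batch(batch, padding=0):
    """Pad every sequence in one batch to that batch's longest input_ids.

    `batch` is assumed to be a list of dicts, one per example, mapping
    each name in PADDED_INPUTS to a list of ints (hypothetical layout).
    """
    max_l = max(len(example["input_ids"]) for example in batch)
    padded = {}
    for name in PADDED_INPUTS:
        # lm_labels are padded with -100, as in pad_dataset above.
        pad_value = -100 if name == "lm_labels" else padding
        padded[name] = [
            example[name] + [pad_value] * (max_l - len(example[name]))
            for example in batch
        ]
    return padded


# Two examples of different lengths, padded only to the batch maximum (3).
batch = [
    {"input_ids": [5, 6, 7], "lm_labels": [-100, 6, 7], "token_type_ids": [1, 1, 2]},
    {"input_ids": [8], "lm_labels": [8], "token_type_ids": [1]},
]
out = pad_batch(batch)
```

A function of this shape could be passed as the `collate_fn` argument of `torch.utils.data.DataLoader`, with the padded lists converted to tensors via `torch.tensor` before returning. Because each batch is padded only to its own longest sequence, the full padded dataset never has to sit in memory at once.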