huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How can I control data feeding order to model using Huggingface Trainer? #18862

Closed: SangwonPark0211 closed this issue 2 years ago

SangwonPark0211 commented 2 years ago

Feature request

I want to train the model in the order in which the data are stored.
For example, if there are 100 examples, I want to feed the 1st and 2nd together (because I set batch_size=2 in my code), then the 3rd and 4th, then the 5th and 6th, and so on.
But the Hugging Face Trainer shuffles the training data (via the random sampler behind its train DataLoader, seeded by data_seed), so the batches do not follow the stored order.
How can I train the model while feeding the data in the order in which they are stored?

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# load tokenizer
model_checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# load model
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# make batch
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

batch_size = 2
epochs = 3
args = Seq2SeqTrainingArguments(
    output_dir = "saved_model",
    overwrite_output_dir = True,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    num_train_epochs=epochs,
    predict_with_generate=True,
    fp16=False,
    dataloader_num_workers=8,
)
# tokenized_datasets and compute_metrics are prepared earlier in my script
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
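(Note for context: the random order comes from the sampler the Trainer uses to build its training DataLoader, not from DataCollatorForSeq2Seq, which only pads and batches. A minimal PyTorch-only sketch of the difference, using a toy list in place of the tokenized dataset:)

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

data = list(range(6))  # toy stand-in for tokenized training examples

# Default Trainer behaviour: a random sampler shuffles the indices.
random_loader = DataLoader(data, batch_size=2, sampler=RandomSampler(data))

# Desired behaviour: a sequential sampler keeps the stored order.
sequential_loader = DataLoader(data, batch_size=2, sampler=SequentialSampler(data))

print([batch.tolist() for batch in sequential_loader])  # [[0, 1], [2, 3], [4, 5]]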

Motivation

I want to control the order in which data are fed to the model.

Your contribution

I want to control the order in which data are fed to the model.

timbmg commented 2 years ago

You can subclass the Seq2SeqTrainer and override the _get_train_sampler method. Instead of creating a RandomSampler object, create a SequentialSampler.

from transformers.trainer_seq2seq import Seq2SeqTrainer
from torch.utils.data import SequentialSampler

class SequentialSeq2SeqTrainer(Seq2SeqTrainer):
    def _get_train_sampler(self) -> SequentialSampler:
        # Yield dataset indices in stored order instead of shuffling them.
        return SequentialSampler(self.train_dataset)
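For completeness, here is a sketch of how the subclass would slot into the snippet from the question (assuming model, args, tokenized_datasets, data_collator, tokenizer, and compute_metrics are defined as above):

trainer = SequentialSeq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()  # batches now follow the dataset's stored order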

SangwonPark0211 commented 2 years ago

Thank you!! I'll try as you mentioned.

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.