huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Support packing for pretokenized datasets #1848

Open kmehant opened 4 months ago

kmehant commented 4 months ago

Currently, TRL returns the dataset as-is if it shows signs of already being tokenized: https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/sft_trainer.py#L503
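For context, here is a minimal sketch of this kind of check. It assumes a dataset counts as "already tokenized" when it exposes an `input_ids` column. The helper name `looks_pretokenized` is hypothetical and is not TRL's API; the real logic is at the link above and may use different criteria.

```python
from typing import Iterable


def looks_pretokenized(column_names: Iterable[str]) -> bool:
    """Hypothetical helper: treat a dataset as pretokenized if it already
    has an `input_ids` column (an assumption, not TRL's exact rule)."""
    return "input_ids" in set(column_names)


# A raw-text dataset goes through tokenization (and optional packing)...
print(looks_pretokenized(["text"]))  # False
# ...while a pretokenized one is returned as-is, so packing is skipped.
print(looks_pretokenized(["input_ids", "attention_mask"]))  # True
```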

Additionally, ConstantLengthDataset (https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/utils.py#L426) was written to support only data that is not pretokenized. It should be possible to extend it to the pretokenized case as well.
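The packing idea behind ConstantLengthDataset can be sketched as operating on token IDs directly: concatenate examples with an EOS separator and cut the stream into fixed-length chunks. This is a simplified streaming variant under stated assumptions. `pack_token_ids` is a hypothetical name, and TRL's actual class buffers examples before chunking and may differ in its details.

```python
from typing import Iterable, Iterator, List


def pack_token_ids(
    sequences: Iterable[List[int]],
    seq_length: int,
    eos_token_id: int,
) -> Iterator[List[int]]:
    """Hypothetical sketch: concatenate pretokenized sequences, separated by
    an EOS token, and yield chunks of exactly `seq_length` token IDs.
    A trailing chunk shorter than `seq_length` is dropped."""
    if seq_length <= 0:
        raise ValueError("seq_length must be positive")
    buffer: List[int] = []
    for ids in sequences:
        # Append the example followed by the separator token.
        buffer.extend(ids)
        buffer.append(eos_token_id)
        # Emit every full chunk currently in the buffer.
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]
    # Leftover tokens (fewer than seq_length) are discarded.


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(list(pack_token_ids(examples, seq_length=4, eos_token_id=0)))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Because the input is already a list of token IDs, no tokenizer call is needed inside the packing loop. Only the EOS token ID has to be known.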

Is there any interest in supporting packing for pretokenized datasets? If so, I would be interested in contributing.

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

kmehant commented 3 months ago

@lvwerra requesting your opinion.

qgallouedec commented 2 months ago

@kmehant thanks for sharing this feature request. Can you briefly describe why you need this feature? Or why you can't do without this feature? It's undoubtedly an interesting feature to have, but I'm worried about the implementation, which risks adding yet another level of complexity. Have you found a way of implementing it? What elements are affected by the changes?

kmehant commented 2 months ago

@qgallouedec thanks for circling back.

In my opinion, supporting this is not complex. Here is a branch implementing it: https://github.com/kmehant/trl/tree/pack-pretok

Changes compared with main: https://github.com/huggingface/trl/compare/main...kmehant:trl:pack-pretok?expand=1

Steps to try this version

Install trl from my fork

git clone -b pack-pretok https://github.com/kmehant/trl.git
cd trl 
pip install .

Sample training code

import datasets
from transformers import AutoTokenizer
from trl import SFTTrainer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tok = AutoTokenizer.from_pretrained(model_id)

# Build a small pretokenized dataset: 10 copies of the same encoded sentence
t = tok.encode("We adopted exactly the same architecture and tokenizer as Llama 2.")
data = datasets.Dataset.from_dict({"input_ids": [t] * 10})

# With packing=True, the pretokenized input_ids are packed into
# sequences of max_seq_length tokens instead of being returned as-is
trainer = SFTTrainer(
    model_id,
    train_dataset=data,
    max_seq_length=10,
    packing=True,
)
trainer.train()

Sample output:

{'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'train_loss': 2.972621281941732, 'epoch': 3.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:18<00:00,  6.15s/it]
TrainOutput(global_step=3, training_loss=2.972621281941732, metrics={'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'total_flos': 3351820124160.0, 'train_loss': 2.972621281941732, 'epoch': 3.0})

Thank you. I can open a PR from this branch and add tests as needed.

qgallouedec commented 2 months ago

Thanks! It's actually simpler than I expected. Can you open a PR? Would it be possible to directly infer if the dataset is tokenized in ConstantLengthDataset?

kmehant commented 2 months ago

@qgallouedec I have raised a PR here: https://github.com/huggingface/trl/pull/2011

Regarding your question, "Would it be possible to directly infer if the dataset is tokenized in ConstantLengthDataset?": thanks, I have included that in the PR.
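Inferring tokenization inside the packing loop could look roughly like the following. This is a sketch under the assumption that an example is pretokenized when it carries an `input_ids` list; otherwise its text field is tokenized first. `iter_token_ids` and the toy whitespace tokenizer are hypothetical and stand in for the PR's code and a real tokenizer.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Union

Example = Dict[str, Union[str, List[int]]]


def iter_token_ids(
    examples: Iterable[Example],
    tokenize: Callable[[str], List[int]],
    text_field: str = "text",
) -> Iterator[List[int]]:
    """Hypothetical sketch: yield token IDs per example, tokenizing only
    when the example does not already carry an `input_ids` list."""
    for ex in examples:
        ids = ex.get("input_ids")
        if isinstance(ids, list):
            yield ids  # already tokenized: use as-is
        else:
            yield tokenize(ex[text_field])  # raw text: tokenize first


# Toy whitespace "tokenizer" standing in for a real one (assumption).
vocab: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    return [vocab.setdefault(w, len(vocab) + 1) for w in text.split()]


mixed = [{"input_ids": [7, 8, 9]}, {"text": "hello world hello"}]
print(list(iter_token_ids(mixed, toy_tokenize)))
# [[7, 8, 9], [1, 2, 1]]
```

The resulting token ID lists can then be fed to the same packing step regardless of whether the source dataset was raw text or pretokenized. This keeps the detection in one place instead of branching the whole data pipeline.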

kmehant commented 2 months ago

@qgallouedec any update on this thread? Thanks