kmehant opened this issue 4 months ago
@lvwerra requesting your opinion.
@kmehant thanks for sharing this feature request. Can you briefly describe why you need this feature? Or why you can't do without this feature? It's undoubtedly an interesting feature to have, but I'm worried about the implementation, which risks adding yet another level of complexity. Have you found a way of implementing it? What elements are affected by the changes?
@qgallouedec thanks for circling back.
In my opinion, supporting this is not complex. Here is a version implementing it: https://github.com/kmehant/trl/tree/pack-pretok
Changes compared with main: https://github.com/huggingface/trl/compare/main...kmehant:trl:pack-pretok?expand=1
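The core idea is just greedy concatenation of already-tokenized sequences into fixed-length blocks. A rough standalone sketch of that idea (illustrative only, not the fork's actual code; the function name and signature are made up for this example):

from itertools import chain

def pack_pretokenized(sequences, block_size, eos_token_id):
    """Concatenate lists of token ids and split them into fixed-length blocks.

    `sequences` is an iterable of lists of token ids; any trailing remainder
    shorter than `block_size` is dropped, mirroring typical packing behaviour.
    """
    # Join all sequences, separating them with the EOS token.
    buffer = list(chain.from_iterable(seq + [eos_token_id] for seq in sequences))
    # Slice the buffer into contiguous, equally sized blocks.
    blocks = [
        buffer[i : i + block_size]
        for i in range(0, len(buffer) - block_size + 1, block_size)
    ]
    return {"input_ids": blocks, "labels": [list(b) for b in blocks]}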
Steps to try this version:
Install trl from my fork:
git clone -b pack-pretok https://github.com/kmehant/trl.git
cd trl
pip install .
Sample training code:
import datasets
from transformers import AutoTokenizer
from trl import SFTTrainer

# Build a tiny pretokenized dataset: the same sentence encoded ten times.
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
t = tok.encode("We adopted exactly the same architecture and tokenizer as Llama 2.")
data = datasets.Dataset.from_dict({"input_ids": [t] * 10})

# Pack the already-tokenized dataset; no text column or formatting function is needed.
trainer = SFTTrainer(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    train_dataset=data,
    max_seq_length=10,
    packing=True,
)
trainer.train()
Sample output looks like
{'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'train_loss': 2.972621281941732, 'epoch': 3.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:18<00:00, 6.15s/it]
TrainOutput(global_step=3, training_loss=2.972621281941732, metrics={'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'total_flos': 3351820124160.0, 'train_loss': 2.972621281941732, 'epoch': 3.0})
Thank you. I can raise a PR out of this and add tests as needed.
Thanks! It's actually simpler than I expected.
Can you open a PR?
Would it be possible to directly infer whether the dataset is tokenized in ConstantLengthDataset?
@qgallouedec I have raised a PR here: https://github.com/huggingface/trl/pull/2011
Would it be possible to directly infer if the dataset is tokenized in ConstantLengthDataset?
Thanks, included that in the PR.
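For context, one way to infer this is to peek at a sample element and check whether it already looks like a sequence of token ids rather than a string. The sketch below only illustrates the heuristic and is not necessarily the exact check in the PR:

def looks_pretokenized(value):
    """Heuristically decide whether a dataset field already holds token ids.

    Returns True for a (possibly nested) list of ints, False for raw text.
    """
    if isinstance(value, str):
        return False
    if isinstance(value, (list, tuple)) and value:
        first = value[0]
        # Either a flat list of ids or a batch of id lists.
        return isinstance(first, int) or (
            isinstance(first, (list, tuple)) and bool(first) and isinstance(first[0], int)
        )
    return False

For the sample dataset above, looks_pretokenized(data[0]["input_ids"]) returns True, while looks_pretokenized("some raw text") returns False.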
@qgallouedec any update on this thread? Thanks
Currently, trl returns the dataset as-is if the provided dataset shows signs of already being tokenized: https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/sft_trainer.py#L503
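Roughly, that check short-circuits when an input_ids column is already present. A simplified paraphrase of the behaviour (not the verbatim source; the helper name is invented for illustration):

import datasets

def _return_if_pretokenized(dataset):
    # If the dataset already exposes an "input_ids" column, treat it as
    # pretokenized and hand it back unchanged; otherwise return None so the
    # caller continues with the usual tokenization/formatting path.
    column_names = (
        dataset.column_names
        if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset))
        else None
    )
    if column_names is not None and "input_ids" in column_names:
        return dataset
    return None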
Additionally, I see that ConstantLengthDataset (https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/utils.py#L426) was written only to support data that is not pretokenized; it should be possible to extend it to the pretokenized case as well.
Is there any interest in supporting packing for pretokenized datasets? If so, I would be interested in contributing.
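Concretely, the extension could be as small as a branch in ConstantLengthDataset's buffering step: tokenize only when an element is still raw text, and pass lists of token ids straight through. A rough sketch of that branch (simplified; the real class also handles shuffling, infinite iteration, and formatting functions, and this helper name is made up for illustration):

def _buffer_to_token_ids(buffer, tokenizer):
    """Turn a buffer of dataset elements into lists of token ids.

    Elements that are already lists of ints pass through untouched; string
    elements go through the tokenizer, as ConstantLengthDataset does today.
    """
    tokenized = []
    for element in buffer:
        if isinstance(element, list) and (not element or isinstance(element[0], int)):
            tokenized.append(element)  # pretokenized: keep as-is
        else:
            tokenized.append(tokenizer(element, add_special_tokens=False)["input_ids"])
    return tokenized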