pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Dataset class for pre-tokenized data with `input_ids` and `labels` keys. #1397

Open tatsuya-flip opened 2 months ago

tatsuya-flip commented 2 months ago

Is there a way to bring in a custom dataset that already has input_ids and labels?

I am looking at the torchtune.datasets classes, and it does not look like this kind of customized training is supported.

Ideally, we would like to provide input_ids and labels as sequences of integers stored in a JSON file and have the trainer load the dataset from it. That way, we can take advantage of the cross-entropy ignore index and hopefully train our model on token sequences that are specific to our dataset. We want low-level access.
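For concreteness, here is a rough sketch of the kind of file we have in mind. The field names, example token ids, and file name are just illustrative; -100 is the default ignore_index of torch.nn.CrossEntropyLoss:

import json

# Each record carries pre-tokenized sequences; positions labeled -100 are
# ignored by the cross-entropy loss.
records = [
    {"input_ids": [1, 4521, 287, 9903, 2], "labels": [-100, -100, 287, 9903, 2]},
    {"input_ids": [1, 1724, 318, 2], "labels": [-100, 1724, 318, 2]},
]

with open("pretokenized.json", "w") as f:
    json.dump(records, f)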

The Hugging Face trainer (i.e. SFTTrainer) lets us pass data_collator=default_data_collator; that way we were able to feed input_ids and labels directly into the model to compute the logits, roughly as sketched below.
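This is only illustrative of our current setup, not exact trl usage (argument details vary by version, and gpt2 is just a small stand-in model):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, TrainingArguments, default_data_collator
from trl import SFTTrainer

# Any causal LM works here; gpt2 is just a placeholder for the sketch.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The pre-tokenized file described above, with input_ids and labels per record.
train_ds = load_dataset("json", data_files="pretokenized.json", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,                # already tokenized, no text field needed
    data_collator=default_data_collator,   # batches input_ids/labels as-is
    args=TrainingArguments(output_dir="out"),
)
trainer.train()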

Thank you

RdoubleA commented 2 months ago

Hi @tatsuya-flip, thanks for the question. This is an interesting use case. The closest thing we have is torchtune.datasets.TextCompletionDataset, which has minimal pre-processing, although the labels are defined as an offset copy of the input_ids. You could make this work with the least amount of changes by copying the TextCompletionDataset class and customizing it slightly:

from typing import Any, Dict, List, Mapping, Optional

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data import truncate


class RawDataset(Dataset):
    """
    Load pre-tokenized input_ids and labels directly from the HF Hub
    or a local file without any further processing.
    """

    def __init__(
        self,
        source: str,
        max_seq_len: Optional[int] = None,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        self._data = load_dataset(source, **load_dataset_kwargs)
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
        tokens, labels = sample["input_ids"], sample["labels"]
        # No EOS is appended here, so simply cut both sequences to max_seq_len
        if self.max_seq_len is not None:
            tokens = truncate(tokens, self.max_seq_len)
            labels = truncate(labels, self.max_seq_len)

        return {"tokens": tokens, "labels": labels}

Then you would simply replace the dataset argument in any config and the recipes should work OOTB. So if your RawDataset class is located in data/dataset.py, the config would look like:

dataset:
  _component_: data.dataset.RawDataset
  source: json
  data_files: path/to/my.json
  split: train
  # any other arguments

This will work with the default collator in our recipes as well.
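For a quick sanity check outside of any recipe, something along these lines should work (the file path, split, and max_seq_len here are just placeholders):

from data.dataset import RawDataset

ds = RawDataset(source="json", data_files="path/to/my.json", split="train", max_seq_len=4096)
sample = ds[0]
assert set(sample.keys()) == {"tokens", "labels"}
assert len(sample["tokens"]) == len(sample["labels"])
print(len(ds), len(sample["tokens"]))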

Also curious about your use case in general, if you don't mind sharing details :) Are you doing some offline custom tokenization, or do you want to define the labels in a customized way?

thusinh1969 commented 1 week ago

I also usually use pre-tokenized data, since I have to manage max_length tightly to make sure NOTHING is cut off.

Steve