allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0

`train` and `val` splits are not disjoint for IMDB #47

Closed by zhixuan-lin 1 year ago

zhixuan-lin commented 1 year ago

First, thanks for the great repo!

It seems that the train and val splits of IMDB are created by two separate calls to `_get_datapool_by_split` (which calls `IMDB.prepare`), and each call shuffles the data independently before taking its slice. As a result, the train and val splits largely overlap within each run. This seems highly problematic, because validation results are then computed on data that largely also appears in the training split, which makes them basically invalid.
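To make the size of the overlap concrete, here is a small standard-library sketch. It does not call RL4LMs or `datasets`; `split_ids` is a hypothetical stand-in for an unseeded `prepare`, and the 80/20 ratio and the 25,000-example size of the IMDB train set are assumptions borrowed for illustration.

```python
import random


def split_ids(n: int, split: str, rng: random.Random, train_ratio: float = 0.8):
    """Hypothetical stand-in for an unseeded prepare(): shuffle, then slice."""
    ids = list(range(n))
    rng.shuffle(ids)  # every call gets its own permutation
    cut = int(n * train_ratio)
    return ids[:cut] if split == "train" else ids[cut:]


n_examples = 25_000  # size of the IMDB train set

# Two separate calls, each with an independent shuffle (different RNG state)
train_ids = set(split_ids(n_examples, "train", random.Random(1)))
val_ids = set(split_ids(n_examples, "val", random.Random(2)))

# Fraction of validation examples that also appear in the training split
overlap = len(train_ids & val_ids) / len(val_ids)
print(f"{overlap:.1%} of val examples are also in train")
```

Under these assumptions, roughly 80% of the validation examples are expected to land in the training split as well, since each validation example is independently likely to fall into the first 80% of the other permutation.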

zhixuan-lin commented 1 year ago

A simple fix would be to add a `seed` parameter for the shuffle in `IMDB.prepare` and pass the same seed to `IMDB.prepare` when creating the train and val splits. A better approach would be to create train and val in a single function call. The class below implements the simple fix:

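```python
# Needed only if this class lives outside the module that defines IMDB;
# the rl4lms import path is assumed from the package layout.
from datasets import load_dataset
from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool


class IMDBFixed(TextGenPool):
    """
    IMDB Dataset for sentiment continuation task

    The original IMDB class defined above has two issues:
    - train and validation sets can overlap
    - the splits are different for each run
    """
    @classmethod
    def prepare(cls, split: str, seed: int = 0):
        dataset = load_dataset("imdb")
        if split in ["train", "val"]:
            # The same seed gives the same permutation on every call, so the
            # 80/20 slices for "train" and "val" are disjoint and reproducible.
            dataset_split = dataset["train"].shuffle(seed=seed)
            train_ratio = 0.8
            train_index = int(len(dataset_split) * train_ratio)
            # Slicing a Dataset returns a dict of columns ({"text": [...], ...})
            dataset_split = dataset_split[:train_index] if split == "train" else dataset_split[train_index:]
        else:
            # Any other split (e.g. "test"): keep 5000 shuffled examples
            dataset_split = dataset[split].shuffle(seed=seed)
            dataset_split = dataset_split[:5000]

        samples = []
        for ix, text in enumerate(dataset_split["text"]):

            # here we consider 50% of tokens as prompt
            prompt_text = text.split(" ")
            prompt_text = " ".join(prompt_text[:int(len(prompt_text) * 0.5)])

            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=prompt_text,
                            references=[text]
                            )
            samples.append(sample)
        pool_instance = cls(samples)
        return pool_instance
```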
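
The "better approach" mentioned above, creating train and val in one function call, could look roughly like the following. This is only a sketch using the standard library: `split_train_val` is a hypothetical helper, not part of RL4LMs, and the 0.8 ratio simply mirrors the value used in `IMDBFixed`.

```python
import random
from typing import List, Sequence, Tuple


def split_train_val(texts: Sequence[str],
                    seed: int = 0,
                    train_ratio: float = 0.8) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: shuffle once, then cut the same permutation in two."""
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)  # single, seeded shuffle
    cut = int(len(indices) * train_ratio)
    train = [texts[i] for i in indices[:cut]]
    val = [texts[i] for i in indices[cut:]]
    return train, val


# Toy data standing in for the IMDB review texts
reviews = [f"review {i}" for i in range(10)]
train_texts, val_texts = split_train_val(reviews, seed=0)
```

Because both splits come from one permutation, they cannot overlap even if a caller forgets to pass a matching seed. With Hugging Face `datasets`, `Dataset.train_test_split` with a fixed `seed` offers a similar single-call split.
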

rajcscw commented 1 year ago

This is fixed in https://github.com/allenai/RL4LMs/pull/50

zhixuan-lin commented 1 year ago

awesome, thanks!