huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

sft example not working #1693

Open TATOAO opened 1 month ago

TATOAO commented 1 month ago

Current version: commit 84156f179f91f519e48185414391d040112f2d34 (HEAD -> main, origin/main, origin/HEAD) updated on Jun 3 2024

I tried to run the following command for examples/scripts/sft.py:

# regular:
python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=64 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing

Error message:

Map:   0%|          | 0/9846 [00:00<?, ? examples/s]
Traceback (most recent call last):
  File "/Users/tatoaoliang/Downloads/Work/trl/examples/scripts/sft.py", line 137, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 360, in __init__
    train_dataset = self._prepare_dataset(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 506, in _prepare_dataset
    return self._prepare_non_packed_dataloader(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 574, in _prepare_non_packed_dataloader
    tokenized_dataset = dataset.map(
                        ^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3156, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3548, in _map_single
    batch = apply_function_on_filtered_inputs(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3417, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 545, in tokenize
    element[dataset_text_field] if not use_formatting_func else formatting_func(element),
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 271, in __getitem__
    value = self.data[key]
            ~~~~~~~~~^^^^^
KeyError: None

I checked the code. Here is the original snippet of the _prepare_non_packed_dataloader function in "trl/trainer/sft_trainer.py", line 529:

    def _prepare_non_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        #### the debugger shows that formatting_func is None and dataset_text_field is None
        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        #### so use_formatting_func is False
        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )

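With both values None, use_formatting_func is False, so tokenize() evaluates element[None], which is exactly the KeyError: None in the traceback. So it seems that formatting_func should not be None here.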
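To make the failing branch concrete, here is a minimal sketch that copies the use_formatting_func condition from the snippet above into a standalone function. The name select_texts is hypothetical, and a plain dict stands in for the batch object that datasets passes to tokenize():

```python
from typing import Callable, Optional


def select_texts(
    element: dict,
    dataset_text_field: Optional[str],
    formatting_func: Optional[Callable[[dict], list]] = None,
) -> list:
    """Hypothetical stand-in for the text-selection step inside tokenize()."""
    # Same condition as in _prepare_non_packed_dataloader
    use_formatting_func = formatting_func is not None and dataset_text_field is None
    # With both arguments None this evaluates element[None] and raises KeyError
    return element[dataset_text_field] if not use_formatting_func else formatting_func(element)


batch = {"text": ["### Human: Hi ### Assistant: Hello"]}

print(select_texts(batch, "text"))                     # column name given: works
print(select_texts(batch, None, lambda b: b["text"]))  # formatting_func given: works
try:
    select_texts(batch, None)                          # the situation in this issue
except KeyError as err:
    print(f"KeyError: {err}")                          # KeyError: None
```

Either a non-None dataset_text_field or a non-None formatting_func avoids the lookup with a None key.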

It is assigned in sft_trainer.py, line 313:

formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

and get_formatting_func_from_dataset is defined in trl/extras/dataset_formatting.py, line 60:

def get_formatting_func_from_dataset(
    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}]
    - `instruction` with [{"prompt": str, "completion": str}]

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if isinstance(dataset, Dataset):
        if "messages" in dataset.features:
            if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "messages")
        if "conversations" in dataset.features:
            if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "conversations")
        elif dataset.features == FORMAT_MAPPING["instruction"]:
            logging.info("Formatting dataset with instruction format")
            return instructions_formatting_function(tokenizer)

    return None
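For a quick check of which branch a dataset would hit, the logic above can be approximated by looking at column names only. This is a simplified, hypothetical sketch (detect_format is not a trl function): the real code compares full feature schemas against FORMAT_MAPPING, and it only reaches the instruction check when there is no "conversations" column.

```python
from typing import AbstractSet, Optional


def detect_format(column_names: AbstractSet[str]) -> Optional[str]:
    """Hypothetical, simplified version of get_formatting_func_from_dataset.

    Only column names are checked here; trl also compares the feature types.
    """
    if "messages" in column_names or "conversations" in column_names:
        return "chatml"
    if set(column_names) == {"prompt", "completion"}:
        return "instruction"
    # Anything else (e.g. a single "text" column) gets no formatting function
    return None


print(detect_format({"text"}))                  # None
print(detect_format({"prompt", "completion"}))  # instruction
print(detect_format({"messages"}))              # chatml
```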

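But the openassistant-guanaco dataset only has a single "text" feature, which matches neither the ChatML nor the instruction format, so get_formatting_func_from_dataset returns None and the dataset is incompatible with this automatic detection.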

https://huggingface.co/datasets/timdettmers/openassistant-guanaco?row=0
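Two workarounds seem consistent with this analysis, though both are untested against this commit: set dataset_text_field="text" (assuming SFTConfig in this version exposes that field, which would make it available as --dataset_text_field="text" on the sft.py command line), or give SFTTrainer an explicit formatting_func so that use_formatting_func becomes True. A minimal sketch of the second option, with a hypothetical helper name and assuming the batched map call hands it a dict of column lists:

```python
def text_formatting_func(batch: dict) -> list:
    """Hypothetical formatting_func: return the raw "text" column unchanged.

    Assumes batched mapping, so batch["text"] is a list of strings and the
    function returns one string per example.
    """
    return list(batch["text"])


batch = {"text": ["### Human: Hi ### Assistant: Hello", "### Human: Bye"]}
print(text_formatting_func(batch))
```

It would then be passed as SFTTrainer(..., formatting_func=text_formatting_func) in examples/scripts/sft.py, with the other arguments left as they are.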

nikhil-tensorwave commented 4 weeks ago

I am also running into the same issue. What other package versions are you using? I am able to run some examples, like the basic SFTTrainer example from the README, but sft.py is not working.

yzjiao commented 2 weeks ago

Same problem here.