intel / llm-on-ray

Pretrain, finetune and serve LLMs on Intel platforms with Ray
Apache License 2.0

A question about finetune dataset processing #234

Open KepingYan opened 3 months ago

KepingYan commented 3 months ago

DataCollatorForCompletionOnlyLM does not seem to work as expected, and I'm not sure whether this affects finetuning performance.

The article "Fine Tuning Your Own ChatGPT-like Model" describes the purpose of DataCollatorForCompletionOnlyLM in dataset preprocessing as follows:

The class method encodes the response key new line into token IDs using the tokenizer, and searches for the start position of the response key in each example's label tensor. Once the start position of the response key is found, the label tensor is modified to mask out all tokens before the end of the response key. This is done by setting the label IDs for those tokens to -100, which is a special value that tells the PyTorch loss function to ignore them.

import numpy as np
import transformers

# Assumed definition (not shown in this issue); the prompts below use "### Response:" as the response key
RESPONSE_KEY = "### Response:"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"

class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # The prompt ends with the response key plus a newline.  We encode this and then try to find it in the
        # sequence of tokens.  This should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
        labels = batch["labels"].clone()
        for i in range(len(examples)):  # batch["labels"][i] holds the labels of the i-th example in the batch
            response_token_ids_start_idx = None
            # Only the first token of the response key is compared, and the loop stops at the first match
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break
            if response_token_ids_start_idx is not None:
                response_token_ids_end_idx = response_token_ids_start_idx + 1
                # Make pytorch loss function ignore all tokens up through the end of the response key
                labels[i, :response_token_ids_end_idx] = -100
        batch["labels"] = labels
        return batch
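
The -100 convention in the quoted description can be made concrete with a small pure-Python sketch (illustrative only, not llm-on-ray code): a mean token-level negative log-likelihood that skips every position labelled -100, which is the default ignore_index of PyTorch's CrossEntropyLoss. The function masked_mean_nll and its toy probability table are hypothetical stand-ins for real logits and losses.

```python
import math

IGNORE_INDEX = -100  # same value as the default ignore_index of torch.nn.CrossEntropyLoss


def masked_mean_nll(probs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    probs[t][v] is a toy probability of token v at position t (hypothetical stand-in for logits).
    Returns 0.0 in this sketch when every position is masked.
    """
    total, count = 0.0, 0
    for row, label in zip(probs, labels):
        if label == IGNORE_INDEX:
            continue  # masked position: contributes nothing to the loss
        total -= math.log(row[label])
        count += 1
    return total / count if count else 0.0


# The first position is masked, so only the second and third positions are averaged
probs = [[0.5, 0.5], [0.25, 0.75], [0.9, 0.1]]
print(masked_mean_nll(probs, [IGNORE_INDEX, 1, 0]))
```

Any prompt token that the collator fails to mask is averaged into the loss exactly like a response token, which is why the masking boundary matters.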

I think this is intended to mask every token from the start of each prompt through the response key with -100, so that training focuses on the response content. However, I found that this function does not achieve this in llm-on-ray at present.

For example:

When the group parameter is false

Prompt content:

  Below is an instruction that describes a task. Write a response that appropriately completes the request.

  ### Instruction:
  Which is a species of fish? Tope or Rope

  ### Response:

  ### End

Expected preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………

  ### End

Actual preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………
  -100 Instruction:
  Which is a species of fish? Tope or Rope

  ### Response:

  ### End

This is because only the first token of the encoded response key, response_token_ids[0] ("###"), is compared in np.where(batch["labels"][i] == response_token_ids[0]), so the "###" of "### Instruction:" is found first instead of the "###" of "### Response:".
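
A minimal pure-Python sketch reproduces this on toy token IDs and shows one possible fix: match the whole encoded response key as a contiguous subsequence instead of only its first token. The sketch works on plain lists rather than tensor rows, the helper names are hypothetical, and the token IDs are made up (2 stands for "###", 3 for "Instruction:", 5 for "Response:", 6 for the newline, 8 for "End"). It also assumes "### Response:\n" tokenizes the same way inside the prompt as on its own, which real tokenizers do not always guarantee.

```python
IGNORE_INDEX = -100


def mask_first_token_match(labels, response_ids):
    """Mimics the current collator: compare only response_ids[0] and stop at the first hit."""
    labels = list(labels)
    for idx, tok in enumerate(labels):
        if tok == response_ids[0]:
            labels[: idx + 1] = [IGNORE_INDEX] * (idx + 1)
            break
    return labels


def find_subsequence(seq, sub, start=0):
    """Return the first index >= start at which sub occurs contiguously in seq, else None."""
    n = len(sub)
    for i in range(start, len(seq) - n + 1):
        if seq[i : i + n] == sub:
            return i
    return None


def mask_full_key_match(labels, response_ids):
    """Mask every position up to and including the end of the full response key."""
    labels = list(labels)
    start = find_subsequence(labels, list(response_ids))
    if start is not None:
        end = start + len(response_ids)
        labels[:end] = [IGNORE_INDEX] * end
    return labels


# Hypothetical IDs: 1 = intro blurb, 2 = "###", 3 = "Instruction:", 4 = question,
# 5 = "Response:", 6 = newline, 8 = "End"
tokens = [1, 2, 3, 4, 2, 5, 6, 2, 8]
response_key = [2, 5, 6]
print(mask_first_token_match(tokens, response_key))  # stops at the "###" of "### Instruction:"
print(mask_full_key_match(tokens, response_key))     # masks through "### Response:\n", keeps "### End"
```

Matching the full key is what separates "### Response:" from "### Instruction:", since both begin with the same "###" token.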

When the group parameter is true

When group is true, multiple prompts are concatenated into a single sequence. Here is an example with two concatenated prompts. Prompt content:

  Below is an instruction that describes a task. Write a response that appropriately completes the request.

  ### Instruction:
  Which is a species of fish? Tope or Rope

  ### Response:

  ### End
  Below is an instruction that describes a task. Write a response that appropriately completes the request.

  ### Instruction:
  Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?

  ### Response:
  The name of the third daughter is Alice

  ### End

Expected preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………

  ### End
  -100 -100 -100 -100 -100 -100 -100…………
  The name of the third daughter is Alice

  ### End

Actual preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………
  -100 Instruction:
  Which is a species of fish? Tope or Rope

  ### Response:

  ### End
  Below is an instruction that describes a task. Write a response that appropriately completes the request.

  ### Instruction:
  Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?

  ### Response:
  The name of the third daughter is Alice

  ### End

This is because the loop breaks at the first match (again the "###" of "### Instruction:"), so the second concatenated prompt is never processed.
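
For grouped sequences, one possible approach (a sketch under stated assumptions, not llm-on-ray's implementation) is to keep scanning after each match: mask from the start of the current segment (the sequence start, or just after the previous end key) through the end of that segment's response key, then continue after the next end key. The helper names, the use of "### End" as the segment separator, and the token IDs (the same made-up vocabulary: 1 intro, 2 "###", 3 "Instruction:", 5 "Response:", 6 newline, 8 "End", other numbers content) are all assumptions.

```python
IGNORE_INDEX = -100


def find_subsequence(seq, sub, start=0):
    """Return the first index >= start at which sub occurs contiguously in seq, else None."""
    n = len(sub)
    for i in range(start, len(seq) - n + 1):
        if seq[i : i + n] == sub:
            return i
    return None


def mask_all_prompts(labels, response_ids, end_ids):
    """Mask every prompt in a grouped sequence, keeping each response and its end key.

    Each segment is assumed to look like: prompt ... response key, response, end key.
    """
    labels = list(labels)
    response_ids, end_ids = list(response_ids), list(end_ids)
    seg_start = 0
    while True:
        key = find_subsequence(labels, response_ids, seg_start)
        if key is None:
            break  # no further response key: leave the remaining tokens untouched
        key_end = key + len(response_ids)
        labels[seg_start:key_end] = [IGNORE_INDEX] * (key_end - seg_start)
        end = find_subsequence(labels, end_ids, key_end)
        if end is None:
            break  # last segment has no end key (e.g. cut off by the block size)
        seg_start = end + len(end_ids)
    return labels


# Two concatenated prompts; the first has an empty response, the second a one-token response (10)
grouped = [1, 2, 3, 4, 2, 5, 6, 2, 8, 1, 2, 3, 9, 2, 5, 6, 10, 2, 8]
print(mask_all_prompts(grouped, [2, 5, 6], [2, 8]))
```

This reproduces the structure of the expected result above: both prompts are masked, while each response and its "### End" are kept. If a grouped block begins partway through a response, that leading fragment is masked as well, which may or may not be desirable.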

KepingYan commented 3 months ago

@harborn @minmingzhu