huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DPO does not work for FIM task with non-instruct model #2382

Open AML14 opened 13 hours ago

AML14 commented 13 hours ago

Reproduction

The script I'm using is a slightly modified version of the official dpo.py example script. The task I'm training the model on is FIM (fill-in-the-middle), not a chat-related task, so I'm using a base model rather than an instruct one.

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )
    peft_config = get_peft_config(model_config)
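    # When a PEFT config is used, DPOTrainer does not need a separate reference model:
    # it recovers the reference log-probabilities by disabling the adapter on the policy model.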
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    dataset = load_dataset("json", data_files=script_args.dataset_name)["train"]

    ################
    # Training
    ################
    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    # Save
    trainer.save_model(training_args.output_dir)

Command:

accelerate launch --config_file acc_config_1.yaml dpo.py \
    --dataset_name dpo_prepared_predictions.json \
    --model_name_or_path Qwen/Qwen2.5-Coder-0.5B \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy no \
    --output_dir Qwen2.5-Coder-0.5B-DPO \
    --no_remove_unused_columns

Example instances from the dataset:

{"prompt":"<|fim_prefix|>protected void doStop()throws Exception {\n    super.doStop();\n    if(muc != null) {\n        <|fim_suffix|>\n        \n    }\n}\n<|fim_middle|>","chosen":"muc.leave();\n        muc = null;<|endoftext|>","rejected":"muc.stop();<|endoftext|>"}
{"prompt":"<|fim_prefix|>public synchronized void addInterceptedNode(ProcessorType node) {\n    if(routeList == null) {\n        routeList = new ArrayList < <|fim_suffix|>\n    \n}\n<|fim_middle|>","chosen":"ProcessorType > ();\n        \n    }\n    routeList.add(node);<|endoftext|>","rejected":">();\n    }\n    routeList.add(node);\n}\n\npublic synchronized void removeInterceptedNode(ProcessorType node) {\n    if(routeList != null) {\n        routeList.remove(node);\n    }\n}\n\npublic synchronized List < ProcessorType > getInterceptedNodes() {\n    return routeList;<|endoftext|>"}

Outputs:

The training runs without errors, but the model is completely broken after DPO training: it doesn't even generate syntactically valid completions, whereas the pre-trained model, or one fine-tuned on the same data, works just fine. For example, given the first input above, the model generates the following output:

// Input:
<|fim_prefix|>protected void doStop()throws Exception {
    super.doStop();
    if(muc != null) {
        <|fim_suffix|>

    }
}
<|fim_middle|>

// Output:
   <|fim_middle|>   <|fim_middle|>   <|fim_middle|>   <|fim_middle|>   <|fim_middle|><|fim_middle|>requests().removeListener(new RequestListener()<|endoftext|>

And given the second input, the model generates the following output (note that replacing <|fim_middle|> in the input with the output does not produce compilable code):

// Input:
<|fim_prefix|>public synchronized void addInterceptedNode(ProcessorType node) {
    if(routeList == null) {
        routeList = new ArrayList < <|fim_suffix|>

}
<|fim_middle|>

// Output:
RouteListImpl();<|endoftext|>
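
The outputs above come from prompting the DPO-trained model with the FIM-formatted input. A minimal generation sketch is shown below; the decoding settings are an assumption, as they are not stated in this report:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "Qwen2.5-Coder-0.5B-DPO"  # the output_dir from the training command above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, device_map="auto")

prompt = (
    "<|fim_prefix|>protected void doStop()throws Exception {\n"
    "    super.doStop();\n"
    "    if(muc != null) {\n"
    "        <|fim_suffix|>\n"
    "        \n"
    "    }\n"
    "}\n"
    "<|fim_middle|>"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Greedy decoding; the actual decoding parameters used for the reported outputs are unknown.
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:]))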

Expected behavior

The model resulting from DPO training should, at the very least, produce compilable code without extra special tokens in the output. Needless to say, it should also improve performance on the task.

AML14 commented 10 hours ago

Update: DPO doesn't even work for a plain code completion task with the base model (i.e., neither the input nor the output includes FIM special tokens). As an example, here is the output generated by Qwen/Qwen2.5-Coder-0.5B for the following input:

// Input:
protected RouteBuilder createRouteBuilder()throws Exception {
    return new RouteBuilder() {

// Output:
        @Override
        public void configure() throws Exception {
            from("direct:hello")
                .to("mock:hello");
        }
    };
}<|endoftext|>

And here is the output of the same model after applying DPO on about 3,000 instances, where the prompt is the input and chosen/rejected are correct/incorrect completions:

// Input:
protected RouteBuilder createRouteBuilder()throws Exception {
    return new RouteBuilder() {

// Output:
public void configure() throws Exception {
<|fim_middle|>
<|fim_middle|>
<|fim_middle|><|endoftext|>

The model is completely broken after applying DPO.