huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

add support for closed-source models in the Generalized Knowledge Distillation Trainer #2179

Closed: imrankh46 closed this issue 3 weeks ago

imrankh46 commented 1 month ago

Feature request

Closed-source model support for GKD, e.g. OpenAI GPT-4o, Claude, etc.

from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

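# Dummy train/eval chat datasets: one conversation repeated NUM_DUMMY_SAMPLES times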
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

Motivation

-

Your contribution

-

imrankh46 commented 1 month ago

@kashif @lewtun

kashif commented 1 month ago

since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models
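For context on why this matters: GKD computes a divergence directly between the student's and teacher's per-token distributions, so both models must produce logits over the same vocabulary. Below is a simplified sketch of the generalized Jensen-Shannon divergence that GKD-style losses optimize (not TRL's exact implementation; tensor shapes are assumed to be (batch, seq_len, vocab_size)):

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Both tensors must share the same final dimension: the vocabulary.
    # A closed API cannot supply teacher logits over the student's vocab,
    # so this quantity is not computable with a closed-source teacher.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    # Log of the mixture M = beta * P_teacher + (1 - beta) * P_student
    mixture_log_probs = torch.logsumexp(
        torch.stack([
            teacher_log_probs + torch.log(torch.tensor(beta)),
            student_log_probs + torch.log(torch.tensor(1.0 - beta)),
        ]),
        dim=0,
    )

    # KL(P_teacher || M) and KL(P_student || M), both computed in log space
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs,
                          reduction="batchmean", log_target=True)
    kl_student = F.kl_div(mixture_log_probs, student_log_probs,
                          reduction="batchmean", log_target=True)
    return beta * kl_teacher + (1 - beta) * kl_student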

August-murr commented 1 month ago

> since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models

The Anthropic API doesn't expose any logits or logprobs, and they have stated no plans to; OpenAI's API returns at most the top 20 logprobs per token. It seems like they really don't want you to distill. OpenAI recently announced a distillation service, but it only distills into their own models, not open-source ones.
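To see the gap concretely, here is a hypothetical probe using the official openai Python client (the model name and prompt are placeholders): even at the documented maximum of top_logprobs=20, each position reveals only 20 alternatives out of a vocabulary of roughly 200k tokens, nowhere near the full distribution a GKD teacher has to provide.

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

response = client.chat.completions.create(
    model="gpt-4o",  # placeholder closed-source teacher
    messages=[{"role": "user", "content": "Hi, how are you?"}],
    logprobs=True,
    top_logprobs=20,  # documented hard cap; larger values are rejected
    max_tokens=8,
)

for position in response.choices[0].logprobs.content:
    # At most 20 (token, logprob) pairs per position: a tiny, truncated
    # slice of the teacher's full output distribution.
    print(position.token, [(t.token, t.logprob) for t in position.top_logprobs])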