stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

forward() got an unexpected keyword argument 'unit_locations' #90

Closed xerkey closed 1 month ago

xerkey commented 1 month ago

I get an error when calling trainer.train(). Please help me!

Error

TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'unit_locations'

Code

import pyreft
import torch
import transformers
import pandas as pd

prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

device='cpu'
model_id = "rinna/llama-3-youko-8b"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id, model_max_length=2048, 
    padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

# https://github.com/matsuvr/OjousamaTalkScriptDataset
df = pd.read_csv('./OjousamaTalkScriptDataset/ojousamatalkscript200.csv')
sample_df = df.sample(20)

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, [prompt_no_input_template % row['prompt'] for _, row in sample_df.iterrows()], 
    [row['completion'] for _, row in sample_df.iterrows()])

reft_config = pyreft.ReftConfig(representations={
    "layer": 8, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
    low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()

training_args = transformers.TrainingArguments(
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 8,
    warmup_steps = 100,
    num_train_epochs = 1,
    learning_rate = 5e-4,
    # bf16 = True,
    logging_steps = 1,
    optim = "paged_adamw_32bit",
    weight_decay = 0.0,
    lr_scheduler_type = "cosine",
    output_dir = "outputs",
    report_to=[]
)

trainer = pyreft.ReftTrainerForCausalLM(model=model, tokenizer=tokenizer, args=training_args, **data_module)

_ = trainer.train()

Environment

pyreft                                  0.0.5
pyvene                                  0.1.1
torch                                   2.0.0
transformers                            4.39.3

frankaging commented 1 month ago

@xerkey Hey, I think you are passing the original model instead of reft_model to your trainer!

trainer = pyreft.ReftTrainerForCausalLM(
    model=model, # <----- here should be reft_model
    tokenizer=tokenizer, args=training_args, **data_module)
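
For anyone hitting the same TypeError, here is a minimal sketch of why the wrong object triggers it. It uses no pyreft code: BaseModel, WrappedModel, and training_step are hypothetical stand-ins, and the assumption is only that the trainer always passes unit_locations to whatever model it was given, which is consistent with the traceback above.

```python
class BaseModel:
    """Hypothetical stand-in for a plain causal LM; forward() has no unit_locations."""

    def forward(self, input_ids):
        return {"logits": [token * 2 for token in input_ids]}


class WrappedModel:
    """Hypothetical stand-in for the wrapper returned by get_reft_model()."""

    def __init__(self, base):
        self.base = base

    def forward(self, input_ids, unit_locations=None):
        # The wrapper consumes unit_locations itself and never forwards it
        # to the base model's forward().
        output = self.base.forward(input_ids)
        output["intervened_at"] = unit_locations
        return output


def training_step(model, batch):
    # Assumption: the trainer always supplies unit_locations, so it only
    # works with a model whose forward() accepts that keyword.
    return model.forward(batch["input_ids"], unit_locations=batch["unit_locations"])


batch = {"input_ids": [1, 2, 3], "unit_locations": [2]}
base = BaseModel()

try:
    training_step(base, batch)  # same mistake as passing model= instead of reft_model
    error_message = None
except TypeError as exc:
    error_message = str(exc)

result = training_step(WrappedModel(base), batch)  # the wrapped model works
print(error_message)
print(result)
```

The first call fails with the same kind of "unexpected keyword argument" message as in the report, and the second succeeds because the wrapper accepts the extra keyword.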

xerkey commented 1 month ago

@frankaging Oh my god! It's a simple mistake... Thank you very much!