stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P1] Refactor ReftTrainer to save artifacts with the config #109

Open · BryanWBear opened this issue 3 weeks ago

BryanWBear commented 3 weeks ago

The issue is that ReftTrainer.save_model does not save the ReftConfig, only the intervention.
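For context, ReftTrainer.save_model appears to delegate straight to pyvene's save_intervention, which writes the intervention weights but nothing about the ReftConfig. A paraphrased sketch from my reading of the source (not a verbatim copy):

import os

def save_model(self, output_dir, _internal_call=False):
    # Paraphrased from the pyreft source; only the intervention is
    # serialized here, so the ReftConfig is lost at checkpoint time.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    self.model.save_intervention(
        save_directory=f"{output_dir}/intervenable_model",
        include_model=True,
    )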

As a workaround, we can load the model from the checkpoint with the following code, reinstantiating the config manually:

import torch
import pyreft
import pyvene as pv
from transformers import AutoModelForCausalLM

# model_name must match the base model that was used during training.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")

# Reinstantiate the ReftConfig by hand, since the trainer does not save it.
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4)})
reft_model = pv.IntervenableModel(reft_config, model)

# Load the trained intervention weights from the trainer checkpoint.
reft_model.load_intervention('./tmp/checkpoint-78/intervenable_model')

# Move the intervention modules onto the same device as the base model.
device = 'cuda'
for k, v in reft_model.interventions.items():
    v[0].to(device)

Please let me know if I am missing something!

Thanks, Bryan

frankaging commented 3 weeks ago

@BryanWBear Yes! I am turning this ticket into a feature request, which I can work on later. Thanks for bringing this up.

For now, to save your reft model, you can also try reft_model.save(<your_dir>) to save using our own API instead of the trainer's API. I think this API saves the config as well as the other artifacts.
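A minimal sketch of that flow, assuming the save/load pattern from the pyreft README (reft_model.set_device, reft_model.save, and pyreft.ReftModel.load; the directory name here is a placeholder):

import torch
import pyreft
from transformers import AutoModelForCausalLM

# Save the interventions together with the config and other artifacts.
reft_model.set_device("cpu")  # move weights off the GPU before serializing
reft_model.save(save_directory="./reft_to_share")

# Later: rebuild the base model, then load everything back in one call.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")
reft_model = pyreft.ReftModel.load("./reft_to_share", model)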