Closed thistleknot closed 1 month ago
Solution: I searched for ".load" and found this issue: https://github.com/stanfordnlp/pyreft/issues/45
There I saw this comment: "You could always load back the model manually: you could first construct the original config and create a reft model; and manually load back saved weights by using torch.load or torch.save for saving. All the interventions are saved in reft_model.interventions as key-value pairs."
#!git clone https://huggingface.co/LaferriereJC/Phi-3-mini-4k-instruct-FOL-pyreft
from transformers import AutoModelForCausalLM
import torch
import pyreft
import os
# Base model that the ReFT interventions are attached to further down.
model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"

# Loading options are gathered in one place so they are easy to scan/change.
load_options = dict(
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **load_options)
# Define the PyReFT configuration: one LoReFT intervention per hidden layer,
# each attached to that layer's output.
layers = range(model.config.num_hidden_layers)
hidden_size = model.config.hidden_size
representations = [
    {
        "component": f"model.layers[{layer_idx}].output",
        "intervention": pyreft.LoreftIntervention(
            embed_dim=hidden_size,
            low_rank_dimension=16,
        ),
    }
    for layer_idx in layers
]
reft_config = pyreft.ReftConfig(representations=representations)

# Initialize the PyReFT model from the base model and the configuration above.
reft_model = pyreft.get_reft_model(model, reft_config)
# Load the saved PyReFT intervention weights from disk, one file per layer.
# The result maps intervention key -> loaded state dict.
local_directory = "./Phi-3-mini-4k-instruct-FOL-pyreft"
interventions = {}
missing_files = []
for layer_idx in layers:
    # The key is what is later looked up in reft_model.interventions; the
    # file on disk is that key prefixed with "intkey_" and suffixed ".bin".
    intervention_key = f"comp.model.layers[{layer_idx}].output.unit.pos.nunit.1#0"
    file_path = os.path.join(local_directory, f"intkey_{intervention_key}.bin")
    if not os.path.exists(file_path):
        # Keep going (best effort), but remember the gap so it is reported.
        missing_files.append(file_path)
        continue
    # map_location="cpu" lets checkpoints saved from a GPU load on any host;
    # the tensors are moved to the target device later.
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from a source you trust, or pass weights_only=True if the installed
    # torch version supports it.
    interventions[intervention_key] = torch.load(file_path, map_location="cpu")

# A skipped file means that layer keeps its freshly initialised intervention,
# so make the situation visible instead of failing silently.
if missing_files:
    print(
        f"Warning: {len(missing_files)} of {len(layers)} intervention files "
        f"were not found in {local_directory}:"
    )
    for missing_path in missing_files:
        print(f"  missing: {missing_path}")
# Apply the loaded weights to the model, reporting any key that the ReFT
# model does not know about.
for component, weights in interventions.items():
    if component not in reft_model.interventions:
        print(f"Key mismatch: {component} not found in reft_model.interventions")
        continue
    # Element 0 of the stored entry is the object that receives the weights
    # (presumably the intervention module — verify against the installed
    # pyvene/pyreft version).
    target = reft_model.interventions[component][0]
    target.load_state_dict(weights)
# Move the ReFT model to CUDA when it is available; otherwise stay on CPU
# rather than crashing on a host without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    print("CUDA is not available; keeping the ReFT model on CPU")
reft_model.set_device(device)

# Verify the model
reft_model.print_trainable_parameters()
Trying to load a saved PyReFT model
how model was init'd
saving
loading
error