stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0
947 stars 77 forks source link

transformers_modules.microsoft.Phi-3-mini-4k-instruct.d269012bea6fbe38ce7752c8940fea010eea3383.modeling_phi3.Phi3ForCausalLM #101

Closed thistleknot closed 1 month ago

thistleknot commented 1 month ago

Trying to load a saved PyReFT model.

How the model was initialized:

import os

import pyreft
import torch  # used by torch.load below; missing from the original snippet

# Define the PyReFT configuration: one LoReFT intervention on the output of
# every decoder layer. `model` is the base causal LM created earlier.
# NOTE(review): low_rank_dimension must match the rank used when the saved
# weights were trained, otherwise load_state_dict will fail on shape mismatch.
layers = range(model.config.num_hidden_layers)
representations = [{
    "component": f"model.layers[{l}].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4
    )
} for l in layers]

reft_config = pyreft.ReftConfig(representations=representations)

# Initialize the PyReFT model
reft_model = pyreft.get_reft_model(model, reft_config)

# Load the saved PyReFT weights. Each intervention is stored in its own file
# named "intkey_<intervention key>.bin", so the filename is derived from the key.
local_directory = "./Phi-3-mini-4k-instruct-FOL-pyreft"
interventions = {}
for l in layers:
    component = f"model.layers[{l}].output"
    key = f"comp.{component}.unit.pos.nunit.1#0"
    file_path = os.path.join(local_directory, f"intkey_{key}.bin")
    if not os.path.exists(file_path):
        # Report instead of silently skipping: a wrong directory would
        # otherwise leave the interventions at their fresh initialization.
        print(f"Missing weights file: {file_path}")
        continue
    # weights_only=True refuses to unpickle arbitrary objects; these files may
    # come from a downloaded Hugging Face repository (requires torch >= 1.13).
    interventions[key] = torch.load(
        file_path, map_location="cpu", weights_only=True
    )

# Integrate the interventions into the model
for component, state_dict in interventions.items():
    if component in reft_model.interventions:
        reft_model.interventions[component][0].load_state_dict(state_dict)
    else:
        print(f"Key mismatch: {component} not found in reft_model.interventions")

# Set the device to CUDA
reft_model.set_device("cuda")

# Verify the model
reft_model.print_trainable_parameters()

Saving:

# Where and how to persist the trained interventions (locally and to the Hub).
save_options = dict(
    save_directory="./fol",
    save_to_hf_hub=True,
    hf_repo_name="LaferriereJC/Phi-3-mini-4k-instruct-FOL-pyreft",
)

# Move the interventions to CPU before saving, then write them out.
reft_model.set_device("cpu")
reft_model.save(**save_options)

Loading:

# NOTE: for this Phi-3 checkpoint (loaded with trust_remote_code=True) this
# call raises KeyError: pyvene's get_dimension_by_component looks the model
# class up in type_to_dimension_mapping, and the remote-code class is not
# registered there. Workaround: rebuild the ReFT model with explicit
# LoreftIntervention(embed_dim=...) instances and load each saved intervention
# state dict with torch.load.
reft_model = pyreft.ReftModel.load(
    "./fol", model
)
# Move the *loaded* model's interventions to the GPU. Calling set_device before
# load would only affect the object that is about to be replaced.
reft_model.set_device("cuda")

Error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[92], line 2
      1 reft_model.set_device("cuda") # send back to cpu before saving.
----> 2 reft_model = pyreft.ReftModel.load(
      3     "./fol", model
      4 )

File /home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/pyreft/reft_model.py:26, in ReftModel.load(*args, **kwargs)
     24 @staticmethod
     25 def load(*args, **kwargs):
---> 26     model = pv.IntervenableModel.load(*args, **kwargs)
     27     return ReftModel._convert_to_reft_model(model)

File /home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:546, in IntervenableModel.load(load_directory, model, local_directory, from_huggingface_hub)
    542     casted_representations += [
    543         RepresentationConfig(*representation_opts)
    544     ]
    545 saving_config.representations = casted_representations
--> 546 intervenable = IntervenableModel(saving_config, model)
    548 # load binary files
    549 for i, (k, v) in enumerate(intervenable.interventions.items()):

File /home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:116, in IntervenableModel.__init__(self, config, model, **kwargs)
    110 intervention_function = (
    111     intervention_type
    112     if type(intervention_type) != list
    113     else intervention_type[i]
    114 )
    115 all_metadata = representation._asdict()
--> 116 component_dim = get_dimension_by_component(
    117     get_internal_model_type(model), model.config, 
    118     representation.component
    119 )
    120 if component_dim is not None:
    121     component_dim *= int(representation.max_number_of_units)

File /home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/pyvene/models/modeling_utils.py:100, in get_dimension_by_component(model_type, model_config, component)
     97 def get_dimension_by_component(model_type, model_config, component) -> int:
     98     """Based on the representation, get the aligning dimension size."""
--> 100     if component not in type_to_dimension_mapping[model_type]:
    101         return None
    103     dimension_proposals = type_to_dimension_mapping[model_type][component]

KeyError: <class 'transformers_modules.microsoft.Phi-3-mini-4k-instruct.d269012bea6fbe38ce7752c8940fea010eea3383.modeling_phi3.Phi3ForCausalLM'>
thistleknot commented 1 month ago

Solution: I searched the issues for `.load` and found this post: https://github.com/stanfordnlp/pyreft/issues/45

It contains this comment: "You could always load back the model manually: you could first construct the original config and create a reft model; and manually load back saved weights by using torch.load or torch.save for saving. All the interventions are saved in reft_model.interventions as key-value pairs."

#!git clone https://huggingface.co/LaferriereJC/Phi-3-mini-4k-instruct-FOL-pyreft
# Workaround for ReftModel.load failing on remote-code models: rebuild the ReFT
# model from the training configuration, then load each saved intervention's
# state dict manually.
import os

import pyreft
import torch
from transformers import AutoModelForCausalLM

model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto"
)

# Define the PyReFT configuration: one LoReFT intervention on the output of
# every decoder layer.
# NOTE(review): low_rank_dimension must match the rank the saved weights were
# trained with, otherwise load_state_dict will fail on shape mismatch.
layers = range(model.config.num_hidden_layers)
representations = [{
    "component": f"model.layers[{l}].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=16
    )
} for l in layers]

reft_config = pyreft.ReftConfig(representations=representations)

# Initialize the PyReFT model
reft_model = pyreft.get_reft_model(model, reft_config)

# Load the saved PyReFT weights. Each intervention is stored in its own file
# named "intkey_<intervention key>.bin", so the filename is derived from the key.
local_directory = "./Phi-3-mini-4k-instruct-FOL-pyreft"
interventions = {}
for l in layers:
    component = f"model.layers[{l}].output"
    key = f"comp.{component}.unit.pos.nunit.1#0"
    file_path = os.path.join(local_directory, f"intkey_{key}.bin")
    if not os.path.exists(file_path):
        # Report instead of silently skipping: a wrong directory would
        # otherwise leave the interventions at their fresh initialization.
        print(f"Missing weights file: {file_path}")
        continue
    # The files come from a cloned Hugging Face repository; weights_only=True
    # refuses to unpickle arbitrary objects (requires torch >= 1.13).
    interventions[key] = torch.load(
        file_path, map_location="cpu", weights_only=True
    )

# Apply the loaded weights to the model
for component, state_dict in interventions.items():
    if component in reft_model.interventions:
        reft_model.interventions[component][0].load_state_dict(state_dict)
    else:
        print(f"Key mismatch: {component} not found in reft_model.interventions")

# Set the device to CUDA
reft_model.set_device("cuda")

# Verify the model
reft_model.print_trainable_parameters()