stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

Getting issue while loading Phi3 in reft_model #80

Closed atharvapatiil closed 4 months ago

atharvapatiil commented 4 months ago
import torch, transformers, pyreft 

model_name = 'microsoft/Phi-3-mini-4k-instruct'
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map='cuda', 
    cache_dir='./workspace', token='', trust_remote_code=True
)

reft_config = pyreft.ReftConfig(
    representations={
        "layer":15,
        "component":"block_output", 
        "low_rank_dimension":4,
        "intervention":pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=4
        ) 
    }
)

reft_model = pyreft.get_reft_model(model, reft_config)

AttributeError                            Traceback (most recent call last)
Cell In[42], line 1
----> 1 reft_model = pyreft.get_reft_model(model, reft_config)

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/pyreft/utils.py:35, in get_reft_model(model, reft_config, set_device)
     31 def get_reft_model(model, reft_config, set_device=True):
     32     """
     33     Create an instance of ReFT model.
     34     """
---> 35     reft_model = ReftModel(reft_config, model)
     36     if set_device:
     37         reft_model.set_device(model.device)

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/pyreft/reft_model.py:14, in ReftModel.__init__(self, config, model, **kwargs)
     13 def __init__(self, config, model, **kwargs):
---> 14     super().__init__(config, model, **kwargs)

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:149, in IntervenableModel.__init__(self, config, model, **kwargs)
    143 if isinstance(
    144     intervention,
    145     CollectIntervention
    146 ):
    147     self.return_collect_activations = True
--> 149 module_hook = get_module_hook(
    150     model, representation
    151 )
    152 self.representations[_key] = representation
    153 self.interventions[_key] = (intervention, module_hook)

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/pyvene/models/modeling_utils.py:159, in get_module_hook(model, representation)
    156 elif representation.component.split(".")[-1] == "output":
    157     hook_type = CONST_OUTPUT_HOOK
--> 159 module = getattr_for_torch_module(model, parameter_name)
    160 module_hook = getattr(module, hook_type)
    162 return module_hook

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/pyvene/models/modeling_utils.py:93, in getattr_for_torch_module(model, parameter_name)
     89     current_module = getattr(current_module, param.split("[")[0])[
     90         int(param.split("[")[-1].strip("]"))
     91     ]
     92 else:
---> 93     current_module = getattr(current_module, param)
     94 return current_module

File ~/anaconda3/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.__getattr__(self, name)
   1686 if name in modules:
   1687     return modules[name]
-> 1688 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'Phi3ForCausalLM' object has no attribute ''

kw2828 commented 4 months ago
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1686             if name in modules:
   1687                 return modules[name]
-> 1688         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1689 
   1690     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'Phi3ForCausalLM' object has no attribute ''

I can confirm this too.

frankaging commented 4 months ago

@atharvapatiil @kw2828 thanks for bringing up the issue!

the reason is that pyvene does not natively support Phi-3 at this point. Please use string access to the module you want to intervene on instead. See this as an example: https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/train.py#L286

this should allow you to intervene on any torch module.

one downside is that you lose the streamlined ReFT model saving and loading: since pyvene does not natively support this model type, you have to write your own saving and loading scripts.
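
A minimal sketch of why the original call fails, assuming the path walk works roughly like the getattr_for_torch_module lines visible in the traceback above. The resolve_path helper and the toy object tree are hypothetical stand-ins, not pyvene code. The empty path imitates what an unsupported model type appears to produce, which would explain the empty attribute name in the error message.

```python
from types import SimpleNamespace


def resolve_path(root, path):
    """Walk a dotted path such as 'model.layers[1].mlp' via getattr and indexing."""
    current = root
    for part in path.split("."):
        if "[" in part:
            # 'layers[1]' -> attribute 'layers', then list index 1
            name = part.split("[")[0]
            index = int(part.split("[")[-1].strip("]"))
            current = getattr(current, name)[index]
        else:
            current = getattr(current, part)
    return current


# Toy stand-in for a decoder-only model; not a real Phi-3 module tree.
layers = [SimpleNamespace(mlp=f"mlp_{i}") for i in range(3)]
toy = SimpleNamespace(model=SimpleNamespace(layers=layers))

print(resolve_path(toy, "model.layers[1].mlp"))  # prints mlp_1

try:
    # Assumption: with no mapping for the model type, the path ends up empty.
    resolve_path(toy, "")
except AttributeError as err:
    print("AttributeError:", err)
```

An explicit string path sidesteps the lookup table entirely, which is why it can target any module that exists on the model.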

frankaging commented 4 months ago

for example, try this:

reft_config = pyreft.ReftConfig(
    representations={
        "component":"model.layers[15].output",   # indicate the component (including the layer) instead
        "low_rank_dimension":4,
        "intervention":pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=4
        ) 
    }
)
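
To find a valid string for a different module, one option is to start from the names that torch's named_modules() yields and convert them to the bracketed form used in the config above. The to_component helper below is a hypothetical convenience, not part of pyreft or pyvene. It assumes the component string is simply the bracketed module path plus a hook suffix such as output, as in the example.

```python
import re


def to_component(module_name, hook="output"):
    """Turn a named_modules()-style name into a bracketed component string."""
    # '.15' (a purely numeric path segment) becomes '[15]'
    bracketed = re.sub(r"\.(\d+)(?=\.|$)", r"[\1]", module_name)
    return f"{bracketed}.{hook}"


print(to_component("model.layers.15"))      # model.layers[15].output
print(to_component("model.layers.15.mlp"))  # model.layers[15].mlp.output
```

With a loaded model, iterating over model.named_modules() and printing each name lists the candidates to pass to this helper.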

willieseun commented 4 months ago

First, adding these two lines helped:

torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

Then using the following string access made it work:

from pyreft import ReftConfig, LoreftIntervention

reft_config = ReftConfig(
    representations={
        "component": "model.layers[15].mlp.output",  # string access to the model component
        "intervention": LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=1
        ),
    }
)

For more insights, you can check out my notebook here: https://www.kaggle.com/code/williamalabi/finetuning-gemma-with-reft Please leave an upvote if you found it helpful :-)

kw2828 commented 4 months ago

@frankaging Thanks for your help, this worked

frankaging commented 4 months ago

Thanks for all the input! I am closing this issue for now. Feel free to reopen it if you have other questions!

bernerprzemek commented 4 months ago

I have one question: how do I load a Phi-3 model back? After calling ReftModel.load, I get:

Traceback (most recent call last):
  File "/gpfs/scratchfs01/site/u/bernerp3/proj/ReFT/eval.py", line 82, in <module>
    reft_model = pyreft.ReftModel.load(
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bernerp3/scratch/conda/envs/dspeed/lib/python3.12/site-packages/pyreft/reft_model.py", line 26, in load
    model = pv.IntervenableModel.load(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bernerp3/scratch/conda/envs/dspeed/lib/python3.12/site-packages/pyvene/models/intervenable_base.py", line 546, in load
    intervenable = IntervenableModel(saving_config, model)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bernerp3/scratch/conda/envs/dspeed/lib/python3.12/site-packages/pyvene/models/intervenable_base.py", line 116, in __init__
    component_dim = get_dimension_by_component(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bernerp3/scratch/conda/envs/dspeed/lib/python3.12/site-packages/pyvene/models/modeling_utils.py", line 100, in get_dimension_by_component
    if component not in type_to_dimension_mapping[model_type]:
                        ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: <class 'transformers.models.phi3.modeling_phi3.Phi3ForCausalLM'>

As I understand it, pyvene has no proper mapping for this model type. Is it possible to solve this somehow?

frankaging commented 4 months ago

@bernerprzemek hey! thanks for raising this.

unfortunately, this is a known issue: pyvene currently cannot automatically load back a model type that it does not support natively.

but here are a couple of ways to solve this:

  1. You could take a look at the pyvene tutorial on adding a new model type to pyvene without changing the library. Search for the keyword: "Add New Model Type".
  2. You could always load the model back manually: first construct the original config and create a ReFT model, then load the saved weights yourself with torch.load (using torch.save when saving). All the interventions are stored in reft_model.interventions as key-value pairs.
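
The second option could look roughly like the sketch below. It is only an illustration: the helper names are hypothetical, and it assumes each intervention is a torch module exposing state_dict and load_state_dict. The traceback earlier in this thread shows interventions stored as (intervention, module_hook) tuples, so the sketch unwraps a tuple if it finds one. Other pyvene versions may store the module directly.

```python
import torch


def _intervention_module(entry):
    # Assumption: an entry is either the module itself or a tuple whose
    # first element is the module, as in (intervention, module_hook).
    return entry[0] if isinstance(entry, tuple) else entry


def save_interventions(reft_model, path):
    """Save every intervention's weights in one file, keyed by intervention key."""
    state = {
        key: _intervention_module(entry).state_dict()
        for key, entry in reft_model.interventions.items()
    }
    torch.save(state, path)


def load_interventions(reft_model, path):
    """Load saved weights into a ReFT model rebuilt from the same config."""
    state = torch.load(path, map_location="cpu")
    for key, entry in reft_model.interventions.items():
        _intervention_module(entry).load_state_dict(state[key])
```

To use it, rebuild the model with the same ReftConfig and pyreft.get_reft_model call as in training, then call load_interventions on the result. The keys only line up if the config is identical.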

bernerprzemek commented 4 months ago

Thanks a lot, I used the second way. And wow, the whole ReFT idea is amazing.