stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P1] Loading a ReFT Llama3 model after fine-tuning with ReFT and LoRA #112

Closed Hamana0509 closed 1 week ago

Hamana0509 commented 1 week ago

I trained and saved ReFT + LoRA modules for the Llama3-8B-Instruct model, but when I load them from the Hugging Face Hub for inference, I get the following error:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 11
      5 path = "Hamana0509/ReFT_Orpo_Llama3_8B_Instruct"
      7 model = transformers.AutoModelForCausalLM.from_pretrained(
      8     model_name, torch_dtype=torch.bfloat16, device_map=device
      9 )
---> 11 reft_model = pyreft.ReftModel.load(path, model, from_huggingface_hub=True)
     13 reft_model.set_device("cuda")

File /opt/conda/lib/python3.10/site-packages/pyreft/reft_model.py:26, in ReftModel.load(*args, **kwargs)
     24 @staticmethod
     25 def load(*args, **kwargs):
---> 26     model = pv.IntervenableModel.load(*args, **kwargs)
     27     return ReftModel._convert_to_reft_model(model)

File /opt/conda/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:547, in IntervenableModel.load(load_directory, model, local_directory, from_huggingface_hub)
    543     casted_representations += [
    544         RepresentationConfig(*representation_opts)
    545     ]
    546 saving_config.representations = casted_representations
--> 547 intervenable = IntervenableModel(saving_config, model)
    549 # load binary files
    550 for i, (k, v) in enumerate(intervenable.interventions.items()):

File /opt/conda/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:124, in IntervenableModel.__init__(self, config, model, **kwargs)
    122     all_metadata["embed_dim"] = component_dim
    123     all_metadata["use_fast"] = self.use_fast
--> 124     intervention = intervention_function(
    125         **all_metadata 
    126     )
    128 if representation.intervention_link_key in self._intervention_pointers:
    129     self._intervention_reverse_link[
    130         _key
    131     ] = f"link#{representation.intervention_link_key}"

File /opt/conda/lib/python3.10/site-packages/pyreft/interventions.py:37, in LoreftIntervention.__init__(self, **kwargs)
     35 def __init__(self, **kwargs):
     36     super().__init__(**kwargs, keep_last_dim=True)
---> 37     rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
     38     self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
     39     self.learned_source = torch.nn.Linear(
     40         self.embed_dim, kwargs["low_rank_dimension"]).to(
     41         kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)

File /opt/conda/lib/python3.10/site-packages/pyreft/interventions.py:19, in LowRankRotateLayer.__init__(self, n, m, init_orth)
     17 super().__init__()
     18 # n > m
---> 19 self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
     20 if init_orth:
     21     torch.nn.init.orthogonal_(self.weight)

TypeError: empty() received an invalid combination of arguments - got (NoneType, int), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

What does this error mean, and how can I fix it?

Hamana0509 commented 1 week ago

@frankaging I have re-implemented ORPOTrainer along the lines of the DPOTrainer example; here is my code: https://colab.research.google.com/drive/1nKikg1c1-J5jGvlrqS995hNo6WqKmFXw?usp=sharing

frankaging commented 1 week ago

@Hamana0509 Thanks for sharing your notebook and raising your question!

I think the problem is that the current model-loading path does not work well when you try to load a LoRA+ReFT model (your traceback shows `embed_dim` coming through as `None` when the intervention is re-instantiated). To resolve this, you have to load the weights manually: create a randomly initialized LoRA+ReFT model with the same configuration you trained with, then load the saved weights back into it, as sketched below.
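A minimal sketch of that workaround, assuming the model was trained with `low_rank_dimension=4` on layer 15 with a LoRA adapter on `o_proj` (adjust everything to your actual training config). The component path through the PEFT wrapper and the `intkey_*.bin` filename pattern are assumptions about how pyvene names its saved intervention binaries, so double-check them against the files in your checkpoint:

```python
import os
import torch
import transformers
import pyreft
from peft import LoraConfig, get_peft_model
from huggingface_hub import snapshot_download

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda"
)

# 1. Re-apply the same LoRA config used during training (randomly initialized).
peft_config = LoraConfig(
    r=4, lora_alpha=32, target_modules=["o_proj"], layers_to_transform=[15],
    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)

# 2. Re-build the same ReFT config used during training (randomly initialized).
#    With a PEFT-wrapped model, the component usually has to be given as an
#    explicit module path through the wrapper; adjust to your setup.
reft_config = pyreft.ReftConfig(representations={
    "layer": 15,
    "component": "base_model.model.model.layers[15].output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=4,
    ),
})
reft_model = pyreft.get_reft_model(model, reft_config)

# 3. Pull the saved checkpoint and load the intervention weights back in manually.
path = snapshot_download("Hamana0509/ReFT_Orpo_Llama3_8B_Instruct")
for key, intervention in reft_model.interventions.items():
    # Some pyvene versions store (intervention, hook) tuples here.
    if isinstance(intervention, tuple):
        intervention = intervention[0]
    # Assumed filename pattern for the saved intervention binaries.
    saved = torch.load(os.path.join(path, f"intkey_{key}.bin"), map_location="cpu")
    intervention.load_state_dict(saved)

reft_model.set_device("cuda")
```

If the LoRA adapter weights were saved separately from the interventions, they would also need to be loaded back onto the PEFT-wrapped model before step 2.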

As for your trainer, feel free to open a PR to submit it! It would be a great contribution. Thanks!

Hamana0509 commented 1 week ago

@frankaging Thank you!