stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0
947 stars · 77 forks

[P1] Model Compatibility #91

Closed SaBay89 closed 1 month ago

SaBay89 commented 1 month ago

I encounter an error when I attempt to use pyreft with GEMMA-2b from the Hugging Face ecosystem. The error occurs regardless of whether I load the model for classification or causal LM tasks.

First, I download the model with the following arguments:

```python
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2b-it",
    num_labels=2,
    quantization_config=bnb_config,  # bnb_config defined earlier
    device_map={"": 0},
    problem_type="multi_label_classification",
    torch_dtype=torch.bfloat16,
)
```

I then use the arguments from this repository to wrap the pyreft library around my downloaded model:

```python
import pyreft

reft_config = pyreft.ReftConfig(representations={
    "layer": 15,
    "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
```

I have experimented with various layers and arguments of the pyreft library. However, irrespective of the layers or arguments I use, I consistently receive the following error message:

```
AttributeError: 'GemmaForSequenceClassification' object has no attribute ''
```

frankaging commented 1 month ago

@SaBay89 Hey, Gemma is not directly supported by pyvene at this point, so you need to use string-based component access in the config:

```python
# get reft model
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    # alternatively, you can specify the component as a string access path,
    # "component": "model.layers[0].output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()
```

"model.layers[0].output" is the string access format. essentially, you need to print out the model architecture and point the component to the intervening module. the suffix .output indicating you are intervening on the module output.

let me know if this helps!

SaBay89 commented 1 month ago

It works. Thank you!