stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0
1.07k stars 90 forks

[P1] Intervention Locations more than Prefix and Suffix #122

Open comeandcode opened 1 month ago

comeandcode commented 1 month ago

Does the backbone, pyvene, support more than two intervention blocks on one layer? When I tried to add a third intervention block to one layer, I got the following error:

        anaconda3/envs/reft_train/lib/python3.10/site-packages/pyvene/models/intervenable_base.py", line 1092, in _wait_for_forward_with_parallel_intervention
            unit_locations_base[
        IndexError: list index out of range

frankaging commented 1 month ago

@comeandcode Hey, what is your intervenable model config? And could you give an example of the unit_locations field you are feeding into the model forward pass? Thanks!

comeandcode commented 1 month ago

Thank you very much for your reply! Suppose I am applying LoReFT to 3 layers, [3, 6, 9]. I create a list, layers = [3, 6, 9, 3, 6, 9, 3, 6, 9], and use a for loop to create the representations:

        reft_representations = [{
            "component": component_mapping["language_model_layers"] % (int(index) - 1),
            "low_rank_dimension": low_rank_dimension,
            "intervention": pyreft.LoreftIntervention(
                embed_dim=model.model.config.hidden_size,
                low_rank_dimension=low_rank_dimension)
        } for index in layers]

The unit_locations passed to the forward pass is a 3D list. Assuming batch_size = 2 and the same prompt length for both prompts, after permute(1, 0, 2).tolist() it looks like this:

        [[[1, 2, 3], [1, 2, 3]],
         [[1, 2, 3], [1, 2, 3]],
         [[1, 2, 3], [1, 2, 3]],
         [[5, 6, 7], [5, 6, 7]],
         [[5, 6, 7], [5, 6, 7]],
         [[5, 6, 7], [5, 6, 7]],
         [[10, 11, 12], [10, 11, 12]],
         [[10, 11, 12], [10, 11, 12]],
         [[10, 11, 12], [10, 11, 12]]]

So the first three entries put locations 1, 2, 3 on the 3 layers, the next three put locations 5, 6, 7 on the same 3 layers, and the last three put locations 10, 11, 12 on them.
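The layout described above can be sketched in plain Python. This is only an illustration of the intended structure, assuming three position groups and one `[batch][positions]` entry per intervention; the helper name `build_intervention_major_locations` is hypothetical and is not part of pyreft or pyvene.

```python
def build_intervention_major_locations(target_layers, position_groups, batch_size):
    """Return (layers, unit_locations) with one entry per intervention.

    layers[i] is the layer that intervention i is attached to, and
    unit_locations[i] has shape [batch_size][n_positions].
    Hypothetical helper for illustration only.
    """
    layers = []
    unit_locations = []
    for positions in position_groups:
        for layer in target_layers:
            layers.append(layer)
            # Every example in the batch shares the same positions here,
            # matching the equal-prompt-length assumption in the comment.
            unit_locations.append([list(positions) for _ in range(batch_size)])
    return layers, unit_locations


layers, unit_locations = build_intervention_major_locations(
    target_layers=[3, 6, 9],
    position_groups=[[1, 2, 3], [5, 6, 7], [10, 11, 12]],
    batch_size=2,
)

# One location entry per intervention: 3 layers x 3 position groups = 9.
assert len(layers) == len(unit_locations) == 9
```

Building both lists in the same loop keeps the i-th location entry aligned with the i-th representation, which is the pairing the description relies on.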

frankaging commented 1 month ago

@comeandcode Thanks. This should be a basic use case of the pyvene APIs.

Did you set up your call like:

        # run intervened forward pass
        unit_locations = None
        if "intervention_locations" in inputs:
            unit_locations={"sources->base": (
                None,
                inputs["intervention_locations"].permute(1, 0, 2).tolist()
            )}

        _, cf_outputs = intervenable(
            {
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"]
            },
            unit_locations=unit_locations,
            labels=inputs["labels"],
            subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
        )

The unit_locations field needs to be set up as a dict.
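As a rough sketch of that dict setup without tensors, the snippet below transposes a batch-major nested list into the intervention-major order that `permute(1, 0, 2).tolist()` produces and wraps it under the `"sources->base"` key. The helper name `to_unit_locations` and its length check are assumptions added for illustration; whether a count mismatch is what triggers the IndexError reported in this issue has not been confirmed in this thread.

```python
def to_unit_locations(batch_major_locations, num_interventions):
    """Wrap [batch][intervention][position] locations in a unit_locations dict.

    Hypothetical helper: mirrors permute(1, 0, 2).tolist() for a
    rectangular nested list, then builds the {"sources->base": ...} dict.
    """
    # Transpose the first two axes: [batch][intervention] -> [intervention][batch].
    intervention_major = [list(per_example) for per_example in zip(*batch_major_locations)]
    # Sanity check (an assumption, not a documented pyvene requirement):
    # one location entry per configured intervention.
    if len(intervention_major) != num_interventions:
        raise ValueError(
            f"expected {num_interventions} location entries, "
            f"got {len(intervention_major)}"
        )
    return {"sources->base": (None, intervention_major)}


batch_major = [
    [[1, 2, 3], [5, 6, 7], [10, 11, 12]],  # example 0
    [[1, 2, 3], [5, 6, 7], [10, 11, 12]],  # example 1
]
example_unit_locations = to_unit_locations(batch_major, num_interventions=3)
```

Comparing the outer length against the number of configured representations before the forward pass is a cheap way to rule out one possible source of index errors.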

comeandcode commented 1 month ago

Thank you for your reply! Yes, I followed the example you kindly provided in Pyreft/example/loreft. Since pyvene supports multiple intervention locations (3 or more) on the same layer, I will look for other mistakes on my side and post them here if I find any. By the way, I noticed that example/loreft seems to skip special tokens such as BOS or EOS, is that right? I printed the intervention locations and found that they start from 1 rather than 0, so only the real prompt tokens are intervened on and the special tokens are left untouched. I am not sure whether I have this right. Thank you!

HongzhengYang commented 2 weeks ago

Hi, did you find out the reason? Thanks!