stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P1] Multiple Positions Intervention #96

Closed: comeandcode closed this issue 1 month ago

comeandcode commented 1 month ago

I want to apply ReFT to a single layer of my model, using a prefix + suffix (multiple positions) scheme such as "f2+l2". I am using get_intervention_locations() (as called in make_multiple_position_supervised_data_module) to build the "intervention_locations" input for the ReFT model. What should num_interventions be? With the code below I just get an empty list. Should the num_interventions argument passed to get_intervention_locations() be multiplied by 2? Even though I only intervene on one layer, there are really two interventions, one for the prefix and one for the suffix, right? Thank you very much!

# parse_positions and get_intervention_locations live in pyreft/dataset.py
from pyreft.dataset import parse_positions, get_intervention_locations

# "f2+l2" -> intervene on the first 2 and last 2 token positions
first_n, last_n = parse_positions("f2+l2")

# toy sequence of 6 token positions
element = [0, 1, 2, 3, 4, 5]
intervention_locations = get_intervention_locations(
    last_position=len(element),
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=1,
    share_weights=False,
)

print(intervention_locations)  # prints [] (empty list)
frankaging commented 1 month ago

@comeandcode Yes, if share_weights=False, then num_interventions is 2 × the number of intervening layers. For instance, if you are intervening only on the last layer, then num_interventions=2.
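
For the single-layer "f2+l2" setup above, a minimal sketch of the corrected call (assuming the helpers are imported from pyreft.dataset, as in the earlier snippet) would be:

from pyreft.dataset import parse_positions, get_intervention_locations

first_n, last_n = parse_positions("f2+l2")  # -> 2, 2

# one intervening layer with share_weights=False
# -> two interventions: one for the prefix, one for the suffix
intervention_locations = get_intervention_locations(
    last_position=6,          # length of the toy sequence above
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=2,      # 2 * number of intervening layers
    share_weights=False,
)

print(intervention_locations)
# should yield something like [[0, 1], [4, 5]]
# (prefix positions first, then suffix positions)

Conversely, with share_weights=True the prefix and suffix positions are handled by a single shared intervention per layer, so num_interventions would simply equal the number of intervening layers.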