zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License
1.85k stars, 223 forks

Edited weights in ROME #391

Open paulyoussef opened 1 day ago

paulyoussef commented 1 day ago

Hi, I'm following the tutorial under tutorial-notebooks to edit GPT2-XL with ROME, but the weights of the edited layer do not change. Also, keep_original_weight is not used in rome_main.py. See the example and output below:

from easyeditor import BaseEditor, ROMEHyperParams

prompts = ['Ray Charles, the',
            'Grant Hill is a professional',
            'The law in Ikaalinen declares the language'
            ]
ground_truth = ['piano',
                'basketball',
                'Finnish'
                ]
target_new = ['violin',
              'soccer',
              'Swedish'
              ]
subject = ['Ray Charles',
            'Grant Hill',
            'Ikaalinen'
            ]

hparams = ROMEHyperParams.from_hparams('./hparams/ROME/gpt2')
editor = BaseEditor.from_hparams(hparams)
print('weights before: ', editor.model.transformer.h[17].mlp.c_proj.weight.detach().cpu().numpy().sum())
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject,
    keep_original_weight=False,
)
print('weights after: ', edited_model.transformer.h[17].mlp.c_proj.weight.detach().cpu().numpy().sum())

The output is (compare weights before to weights after):


weights before:  469.14468
100%|██████████| 3/3 [00:00<00:00,  9.02it/s]
  0%|          | 0/3 [00:00<?, ?it/s]
Executing ROME algorithm for the update: [Ray Charles, the] -> [ violin]
Cached context templates ['{}', 'The first thing I did. {}', 'The most common form of. {}', 'Therefore, if we are. {}', 'Therefore, I would suggest. {}', 'Because it is not a. {}', "Because I'm the only. {}", 'I am very proud to. {}', "I'm going to be. {}", "You'll be able to. {}", 'You can find the latest. {}', 'The most recent data from the U.S.. {}', 'The same day, I received a message from the. {}', 'Therefore I have come to the conclusion that the time. {}', 'Therefore, the question arises whether the government should be. {}', "Because it's the first day and you've just. {}", 'Because of this, I have decided to make this. {}', 'I think the most important thing that I can say. {}', 'I have to admit to being surprised that the ". {}', 'You\'re going to be fine." ". {}', "You'll see a lot more of that in the. {}"]
Computing left vector (u)...
Selected u projection object Ray Charles
Left vector shape: torch.Size([6400])
Computing right vector (v)
Lookup index found: 1 | Sentence: Ray Charles, the | Token:  Charles
Rewrite layer is 17
Tying optimization objective to 47
Recording initial value of v*
loss 10.598 = 10.598 + 0.0 + 0.0 avg prob of [ violin] 2.707893872866407e-05
loss 7.003 = 6.975 + 0.009 + 0.019 avg prob of [ violin] 0.000973085465375334
loss 4.71 = 4.652 + 0.026 + 0.032 avg prob of [ violin] 0.01042233593761921
loss 3.146 = 3.068 + 0.034 + 0.044 avg prob of [ violin] 0.050399281084537506
loss 1.919 = 1.819 + 0.046 + 0.055 avg prob of [ violin] 0.17390207946300507
loss 0.912 = 0.769 + 0.079 + 0.064 avg prob of [ violin] 0.4859868884086609
loss 0.445 = 0.263 + 0.109 + 0.073 avg prob of [ violin] 0.7775669097900391
loss 0.279 = 0.124 + 0.074 + 0.082 avg prob of [ violin] 0.886086642742157
loss 0.219 = 0.079 + 0.053 + 0.086 avg prob of [ violin] 0.9250359535217285
loss 0.19 = 0.056 + 0.048 + 0.086 avg prob of [ violin] 0.9467006921768188
loss 0.171 = 0.039 + 0.045 + 0.086 avg prob of [ violin] 0.9622268676757812
loss 0.157 = 0.028 + 0.043 + 0.086 avg prob of [ violin] 0.9730653762817383
loss 0.147 = 0.02 + 0.041 + 0.086 avg prob of [ violin] 0.9803466200828552
loss 0.141 = 0.015 + 0.039 + 0.086 avg prob of [ violin] 0.9851239919662476
loss 0.137 = 0.012 + 0.038 + 0.086 avg prob of [ violin] 0.9882499575614929
loss 0.134 = 0.01 + 0.038 + 0.086 avg prob of [ violin] 0.990328848361969
loss 0.132 = 0.008 + 0.037 + 0.086 avg prob of [ violin] 0.9917581081390381
loss 0.131 = 0.007 + 0.037 + 0.086 avg prob of [ violin] 0.9927856922149658
loss 0.13 = 0.006 + 0.037 + 0.086 avg prob of [ violin] 0.9935641884803772
2024-10-18 16:02:32,310 - easyeditor.editors.editor - INFO - 0 editing: Ray Charles, the -> violin  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Ray Charles, the', 'target_new': 'violin', 'ground_truth': 'piano', 'portability': {}, 'locality': {}, 'subject': 'Ray Charles'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
10/18/2024 16:02:32 - INFO - easyeditor.editors.editor -   0 editing: Ray Charles, the -> violin  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Ray Charles, the', 'target_new': 'violin', 'ground_truth': 'piano', 'portability': {}, 'locality': {}, 'subject': 'Ray Charles'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
 33%|███▎      | 1/3 [00:05<00:10,  5.41s/it]
loss 0.129 = 0.006 + 0.037 + 0.086 avg prob of [ violin] 0.9941849708557129
Delta norm: 92.5128173828125
Change in target norm: 23.128206253051758 to 95.5408935546875 => 72.41268920898438
Division Factor: 13.210683822631836
Right vector norm: 7.0028791427612305
Right vector shape: torch.Size([1600])
Deltas successfully computed for ['transformer.h.17.mlp.c_proj.weight']
New weights successfully inserted into ['transformer.h.17.mlp.c_proj.weight']
Executing ROME algorithm for the update: [Grant Hill is a professional] -> [ soccer]
Computing left vector (u)...
Selected u projection object Grant Hill
Left vector shape: torch.Size([6400])
Computing right vector (v)
Lookup index found: 1 | Sentence: Grant Hill is a professional | Token:  Hill
Rewrite layer is 17
Tying optimization objective to 47
Recording initial value of v*
loss 5.895 = 5.895 + 0.0 + 0.0 avg prob of [ soccer] 0.003882689168676734
loss 4.043 = 4.019 + 0.006 + 0.018 avg prob of [ soccer] 0.022225558757781982
loss 2.465 = 2.428 + 0.008 + 0.028 avg prob of [ soccer] 0.11992382258176804
loss 1.217 = 1.171 + 0.009 + 0.036 avg prob of [ soccer] 0.37728649377822876
loss 0.695 = 0.639 + 0.012 + 0.044 avg prob of [ soccer] 0.557979941368103
loss 0.419 = 0.353 + 0.014 + 0.052 avg prob of [ soccer] 0.713262677192688
loss 0.283 = 0.208 + 0.016 + 0.058 avg prob of [ soccer] 0.8155384063720703
loss 0.218 = 0.135 + 0.018 + 0.065 avg prob of [ soccer] 0.8749680519104004
loss 0.186 = 0.096 + 0.02 + 0.071 avg prob of [ soccer] 0.9088634252548218
loss 0.171 = 0.073 + 0.021 + 0.076 avg prob of [ soccer] 0.9296196699142456
loss 0.162 = 0.058 + 0.023 + 0.081 avg prob of [ soccer] 0.9435552358627319
loss 0.156 = 0.048 + 0.023 + 0.085 avg prob of [ soccer] 0.9531516432762146
loss 0.15 = 0.042 + 0.023 + 0.085 avg prob of [ soccer] 0.9590563178062439
loss 0.144 = 0.037 + 0.023 + 0.085 avg prob of [ soccer] 0.964073896408081
loss 0.139 = 0.032 + 0.022 + 0.085 avg prob of [ soccer] 0.9683307409286499
loss 0.135 = 0.029 + 0.022 + 0.085 avg prob of [ soccer] 0.9719178676605225
loss 0.131 = 0.025 + 0.021 + 0.085 avg prob of [ soccer] 0.9749225378036499
loss 0.128 = 0.023 + 0.02 + 0.085 avg prob of [ soccer] 0.977435290813446
loss 0.125 = 0.021 + 0.02 + 0.085 avg prob of [ soccer] 0.9795466065406799
2024-10-18 16:02:36,765 - easyeditor.editors.editor - INFO - 1 editing: Grant Hill is a professional -> soccer  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Grant Hill is a professional', 'target_new': 'soccer', 'ground_truth': 'basketball', 'portability': {}, 'locality': {}, 'subject': 'Grant Hill'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
10/18/2024 16:02:36 - INFO - easyeditor.editors.editor -   1 editing: Grant Hill is a professional -> soccer  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Grant Hill is a professional', 'target_new': 'soccer', 'ground_truth': 'basketball', 'portability': {}, 'locality': {}, 'subject': 'Grant Hill'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
 67%|██████▋   | 2/3 [00:09<00:04,  4.85s/it]
loss 0.122 = 0.019 + 0.019 + 0.085 avg prob of [ soccer] 0.981339693069458
Delta norm: 94.6628646850586
Change in target norm: 23.66571617126465 to 98.17727661132812 => 74.51155853271484
Division Factor: 14.136369705200195
Right vector norm: 6.696405410766602
Right vector shape: torch.Size([1600])
Deltas successfully computed for ['transformer.h.17.mlp.c_proj.weight']
New weights successfully inserted into ['transformer.h.17.mlp.c_proj.weight']
Executing ROME algorithm for the update: [The law in Ikaalinen declares the language] -> [ Swedish]
Computing left vector (u)...
Selected u projection object Ikaalinen
Left vector shape: torch.Size([6400])
Computing right vector (v)
Lookup index found: 6 | Sentence: The law in Ikaalinen declares the language | Token: en
Rewrite layer is 17
Tying optimization objective to 47
Recording initial value of v*
loss 9.87 = 9.87 + 0.0 + 0.0 avg prob of [ Swedish] 6.724189734086394e-05
loss 8.168 = 8.132 + 0.006 + 0.03 avg prob of [ Swedish] 0.000428619678132236
loss 6.786 = 6.722 + 0.015 + 0.049 avg prob of [ Swedish] 0.0016584836412221193
loss 4.802 = 4.71 + 0.026 + 0.066 avg prob of [ Swedish] 0.011063705198466778
loss 2.514 = 2.397 + 0.034 + 0.083 avg prob of [ Swedish] 0.09913279861211777
loss 1.27 = 1.132 + 0.039 + 0.099 avg prob of [ Swedish] 0.33164602518081665
loss 0.764 = 0.611 + 0.044 + 0.109 avg prob of [ Swedish] 0.5488702058792114
loss 0.456 = 0.303 + 0.044 + 0.109 avg prob of [ Swedish] 0.7422630190849304
loss 0.289 = 0.138 + 0.043 + 0.109 avg prob of [ Swedish] 0.8728029131889343
loss 0.217 = 0.067 + 0.041 + 0.109 avg prob of [ Swedish] 0.9358943700790405
loss 0.187 = 0.037 + 0.041 + 0.109 avg prob of [ Swedish] 0.9634790420532227
loss 0.174 = 0.024 + 0.041 + 0.109 avg prob of [ Swedish] 0.9760932922363281
loss 0.167 = 0.018 + 0.041 + 0.109 avg prob of [ Swedish] 0.9824212789535522
loss 0.163 = 0.014 + 0.04 + 0.109 avg prob of [ Swedish] 0.9859662652015686
loss 0.16 = 0.012 + 0.039 + 0.109 avg prob of [ Swedish] 0.9882103204727173
loss 0.157 = 0.01 + 0.037 + 0.109 avg prob of [ Swedish] 0.9898483753204346
loss 0.154 = 0.009 + 0.036 + 0.109 avg prob of [ Swedish] 0.9912679195404053
loss 0.152 = 0.007 + 0.035 + 0.109 avg prob of [ Swedish] 0.9927050471305847
loss 0.149 = 0.006 + 0.034 + 0.109 avg prob of [ Swedish] 0.9942309260368347
2024-10-18 16:02:42,021 - easyeditor.editors.editor - INFO - 2 editing: The law in Ikaalinen declares the language -> Swedish  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'The law in Ikaalinen declares the language', 'target_new': 'Swedish', 'ground_truth': 'Finnish', 'portability': {}, 'locality': {}, 'subject': 'Ikaalinen'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
10/18/2024 16:02:42 - INFO - easyeditor.editors.editor -   2 editing: The law in Ikaalinen declares the language -> Swedish  

 {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'The law in Ikaalinen declares the language', 'target_new': 'Swedish', 'ground_truth': 'Finnish', 'portability': {}, 'locality': {}, 'subject': 'Ikaalinen'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}
100%|██████████| 3/3 [00:15<00:00,  5.04s/it]
loss 0.147 = 0.004 + 0.033 + 0.109 avg prob of [ Swedish] 0.9957217574119568
Delta norm: 73.40950012207031
Change in target norm: 18.352375030517578 to 74.62543487548828 => 56.2730598449707
Division Factor: 13.669366836547852
Right vector norm: 5.370366096496582
Right vector shape: torch.Size([1600])
Deltas successfully computed for ['transformer.h.17.mlp.c_proj.weight']
New weights successfully inserted into ['transformer.h.17.mlp.c_proj.weight']
Metrics Summary:  {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 1.0}}
weights after:  469.14468
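
For readers following the log: ROME (Rank-One Model Editing) writes a rank-one update into the edited matrix, built from the left vector u (6400 entries here) and the right vector v (1600 entries here). The following is a minimal pure-Python sketch of that idea on a toy 3x2 matrix. The helper names `rank_one_update`, `max_abs_diff`, and `total` are hypothetical and this is not EasyEdit's code. It also illustrates that a sum is only a coarse change detector, while an elementwise comparison against a saved copy is a more direct check.

```python
def rank_one_update(W, u, v):
    """Return a new matrix W + outer(u, v); W itself is left untouched."""
    return [[W[i][j] + u[i] * v[j] for j in range(len(v))]
            for i in range(len(u))]


def max_abs_diff(A, B):
    """Largest elementwise absolute difference between two matrices."""
    return max(abs(a - b) for row_a, row_b in zip(A, B)
               for a, b in zip(row_a, row_b))


def total(W):
    """Sum of all entries, the same kind of check as .sum() in the post."""
    return sum(sum(row) for row in W)


# Toy values chosen so that sum(u) == 0: the update changes entries
# but leaves the overall sum untouched.
W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
u = [1.0, -1.0, 0.0]
v = [0.5, 0.5]

W_new = rank_one_update(W, u, v)

print("sum before:", total(W))
print("sum after: ", total(W_new))
print("max |diff|:", max_abs_diff(W, W_new))
```

In this toy case the two sums are equal even though four entries moved by 0.5, because the sum of a rank-one update is sum(u) * sum(v). With real tensors, a comparable check would be to keep a cloned copy of the weight before editing and compare it elementwise afterwards.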

littlefive5 commented 19 hours ago

Hello, keep_original_weight is deprecated in our new version; you can set sequential_edit to True instead:

metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject,
    sequential_edit=True,
)
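
One way to picture the difference between the two modes is the toy pure-Python sketch below. It assumes, without confirmation in this thread, that when sequential_edit is not set the editor evaluates each edit and then restores the original weights, which would be consistent with the identical sums reported above. `ToyEditor` and its `edit` method are hypothetical stand-ins and are not EasyEdit's API.

```python
import copy


class ToyEditor:
    """Hypothetical stand-in for an editor; NOT EasyEdit's implementation."""

    def __init__(self, weights):
        self.weights = weights  # dict: parameter name -> list of floats

    def edit(self, deltas, sequential_edit=False):
        # Snapshot the original values so they can be restored later.
        backup = copy.deepcopy(self.weights)
        post_edit_sums = []
        for name, delta in deltas:
            # Apply the edit in place of the current parameter values.
            self.weights[name] = [w + d for w, d in zip(self.weights[name], delta)]
            # "Evaluate" the edited model; here we only record the sum.
            post_edit_sums.append(sum(self.weights[name]))
            if not sequential_edit:
                # Assumed behaviour: undo the edit once it has been evaluated.
                self.weights = copy.deepcopy(backup)
        return post_edit_sums, self.weights


deltas = [("w", [1.0, 0.0]), ("w", [0.0, 2.0])]

sums_restored, final_restored = ToyEditor({"w": [1.0, 1.0]}).edit(deltas)
sums_kept, final_kept = ToyEditor({"w": [1.0, 1.0]}).edit(deltas, sequential_edit=True)

print(sums_restored, final_restored)
print(sums_kept, final_kept)
```

Under that assumption, each edit can succeed while it is evaluated (as the rewrite_acc of 1.0 in the log suggests), yet the returned weights equal the originals unless the edits are kept sequentially.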

zxlzr commented 7 hours ago

Hi, do you have any further issues?

paulyoussef commented 7 hours ago

Thanks for the quick response. Where do I find more information about the new version?