zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License

Suggestions for PMET Method on Llama #247

Closed. Lut-hub closed this issue 3 months ago.

Lut-hub commented 4 months ago

Hello!

I came across a tiny issue while editing Llama2 using the PMET method, and I thought it might be worth mentioning: At line 44 of the file easyeditor/models/pmet/compute_zs.py:

target_ids = tok(request["target_new"], return_tensors="pt").to("cuda")["input_ids"][0]

The LlamaTokenizer prepends an extra "<s>" (BOS) token when encoding request["target_new"]. In the subsequent processing, this "<s>" token ends up appended to the end of the query, which makes it difficult for PMET to optimize zs effectively.
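To make the mechanism concrete, here is a minimal sketch using a hypothetical toy tokenizer (ToyLlamaTokenizer is not a real library class; it only imitates LlamaTokenizer's default of prepending "<s>" when encoding). It also rests on an assumption not stated in this thread: that the rewriting prompt is built by decoding all target tokens except the last and appending them to the query, which would explain the trailing "<s>" in the lookup sentence of the PMET log.

```python
class ToyLlamaTokenizer:
    """Hypothetical stand-in: encode() prepends BOS like LlamaTokenizer's default."""

    bos_token_id = 1
    unk_token_id = 0

    def __init__(self):
        # Toy vocabulary; "German" gets a made-up id.
        self.vocab = {"<unk>": 0, "<s>": 1, "German": 1000}
        self.inv_vocab = {i: t for t, i in self.vocab.items()}

    def encode(self, text):
        # Toy rule: the whole string maps to one token (unknown text -> <unk>),
        # with the BOS id prepended.
        return [self.bos_token_id, self.vocab.get(text, self.unk_token_id)]

    def decode(self, ids):
        return "".join(self.inv_vocab[i] for i in ids)


tok = ToyLlamaTokenizer()
query = "What is the native language of Christiane Cohendy?"
target_ids = tok.encode("German")  # BOS is silently included: [1, 1000]

# Assumed construction: teacher-force every target token except the last,
# so target_ids[:-1] is decoded and appended to the query.
rewriting_prompt = query + tok.decode(target_ids[:-1])
print(target_ids)        # [1, 1000]
print(rewriting_prompt)  # What is the native language of Christiane Cohendy?<s>
```

With a one-token target such as "German", target_ids[:-1] would be empty if no BOS were added, so the only thing appended to the query here is the stray "<s>".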

For example, when we edit the object of "What is the native language of Christiane Cohendy?" to "German", the result is:

Lookup index found: 12 | Sentence: What is the native language of Christiane Cohendy?<s> | Token: y
Rewrite layer is 8
Tying optimization objective to 31
Recording initial value of v* in attn
Recording initial value of v* in mlp
loss 11.094 = 11.094 + 0.0 + 0.0 avg prob of [ German] 2.2592015739064664e-05
loss 10.079 = 10.078 + 0.0 + 0.001 avg prob of [ German] 6.582784408237785e-05
loss 9.424 = 9.423 + 0.0 + 0.001 avg prob of [ German] 0.0001342837349511683
loss 7.538 = 7.538 + 0.0 + 0.001 avg prob of [ German] 0.0012688480783253908
...

By contrast, MEMIT produces the expected output, with no trailing "<s>":

Lookup index found: 12 | Sentence: What is the native language of Christiane Cohendy? | Token: y
Rewrite layer is 8
Tying optimization objective to 31
Recording initial value of v*
loss 6.46 = 6.46 + 0.0 + 0.0 avg prob of [ German] 0.002235208638012409
loss 4.531 = 4.399 + 0.099 + 0.033 avg prob of [ German] 0.013420548290014267
loss 3.512 = 3.249 + 0.231 + 0.033 avg prob of [ German] 0.04390619695186615
loss 2.939 = 2.666 + 0.24 + 0.033 avg prob of [ German] 0.07955510914325714
...

MEMIT handles this case with the following check at line 43 of the file easyeditor/models/memit/compute_z.py:

if target_ids[0] == tok.bos_token_id or target_ids[0] == tok.unk_token_id:
        target_ids = target_ids[1:]

So it would be better to add the same check at line 47 of the file easyeditor/models/pmet/compute_zs.py. 😊
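The MEMIT check quoted above can be sketched as a small standalone helper. This version operates on a plain Python list rather than the torch tensor used in the original file, and strip_leading_special, BOS_ID, UNK_ID and GERMAN_ID are hypothetical names introduced for illustration. The BOS and UNK ids of 1 and 0 match the Llama 2 tokenizer, while 1000 is a made-up target id.

```python
def strip_leading_special(target_ids, bos_token_id, unk_token_id):
    # Same condition as the MEMIT check: drop the first id if it is BOS or UNK.
    # The length guard is an extra safety measure for empty inputs.
    if len(target_ids) > 0 and target_ids[0] in (bos_token_id, unk_token_id):
        return target_ids[1:]
    return target_ids


BOS_ID, UNK_ID = 1, 0  # Llama 2 tokenizer ids for "<s>" and "<unk>"
GERMAN_ID = 1000       # hypothetical id standing in for the target token

raw_ids = [BOS_ID, GERMAN_ID]
fixed_ids = strip_leading_special(raw_ids, BOS_ID, UNK_ID)
print(fixed_ids)       # [1000]
print(fixed_ids[:-1])  # [] so nothing extra would be appended to the query
```

Because the helper only drops a leading BOS or UNK id, calling it on ids that are already clean leaves them unchanged.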

XeeKee commented 4 months ago

Thank you very much for your interest in EasyEdit. We apologize for our limited availability; we are currently busy with the NeurIPS submission deadline. We will focus on optimization once the deadline has passed.

pengzju commented 3 months ago

Thank you for your suggestion. I will update the whole codebase to use tok.encode(xx, add_special_tokens=False) so that no unnecessary tokens are added.
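This alternative avoids producing the extra token in the first place. With a real Hugging Face tokenizer the call is tok.encode(text, add_special_tokens=False); the sketch below substitutes a hypothetical ToyTokenizer so it runs without downloading a Llama 2 tokenizer, and it only imitates how the add_special_tokens flag controls the prepended BOS id. The vocabulary ids are made up.

```python
class ToyTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer whose encode()
    prepends BOS unless add_special_tokens=False."""

    bos_token_id = 1

    def __init__(self, vocab):
        self.vocab = vocab

    def encode(self, text, add_special_tokens=True):
        # Toy rule: split on whitespace and look each word up in the vocab.
        ids = [self.vocab[w] for w in text.split()]
        return [self.bos_token_id] + ids if add_special_tokens else ids


tok = ToyTokenizer({"German": 1000, "French": 1001})  # hypothetical ids

with_bos = tok.encode("German")                                # default behavior
without_bos = tok.encode("German", add_special_tokens=False)  # proposed fix
print(with_bos)     # [1, 1000]
print(without_bos)  # [1000]
```

One difference from the MEMIT-style check: that check removes a leading BOS or UNK id after the fact, whereas add_special_tokens=False stops the tokenizer from inserting any special tokens at all, which also covers tokenizers configured to add an EOS token.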

Lut-hub commented 3 months ago

Thanks for your reply 😊