zjunlp / EasyEdit

[Knowledge Editing] [ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License
1.63k stars 200 forks

locality and portability evaluation #324

Closed SXxinxiaosong closed 2 weeks ago

SXxinxiaosong commented 2 weeks ago

Hello, which datasets are used to evaluate locality and portability in Table 4? Is it the data that ships with the test set? For example, for ZsRE locality, do you use the locality field in ZsRE-test-all.json?

littlefive5 commented 2 weeks ago

Hello, yes. Because of recent EasyEdit updates, the results of some methods have improved. We are updating them and will post the latest results soon.

SXxinxiaosong commented 2 weeks ago

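Two more questions, please:

  1. If one edit in the dataset has multiple locality entries, is one of them selected at random?
  2. Is portability evaluated with the additional dataset that was mentioned, zsre_portability_gpt4.json, or with the data in ZsRE-test-all.json?

littlefive5 commented 2 weeks ago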

1. Currently we compute the average over all of them.
2. Table 4 uses ZsRE-test-all.json.
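As a rough illustration of "averaging over all of them" for a single edit, here is a minimal sketch. The function name `locality_score` and the exact-match comparison between pre-edit and post-edit outputs are assumptions made for illustration; they are not taken from EasyEdit's metric code.

```python
from statistics import mean


def locality_score(pre_edit_outputs, post_edit_outputs):
    """Average agreement between pre- and post-edit outputs
    over all locality prompts that belong to one edit (hypothetical helper)."""
    if len(pre_edit_outputs) != len(post_edit_outputs):
        raise ValueError("need one post-edit output per pre-edit output")
    if not pre_edit_outputs:
        return None  # this edit has no locality prompts
    # 1.0 when the output is unchanged by the edit, 0.0 otherwise.
    return mean(
        float(pre == post)
        for pre, post in zip(pre_edit_outputs, post_edit_outputs)
    )


# Two locality prompts for one edit; the second output changed after editing.
pre = ["music streaming", "music streaming"]
post = ["music streaming", "CD"]
print(locality_score(pre, post))  # 0.5
```

Every locality prompt of the edit contributes to the score, rather than one being picked at random.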

SXxinxiaosong commented 2 weeks ago

Great, thank you!!

SXxinxiaosong commented 2 weeks ago
assert len(locality_inputs[locality_key]['prompt']) == len(locality_inputs[locality_key]['ground_truth']) \
            == len(requests), print('One Edit instance needs one locality input.....')

But doesn't this mean that each edit can have only one locality input?

littlefive5 commented 2 weeks ago

Hello, in KnowEdit's data processing, the locality entries of each edit are stored as a list. You can debug it yourself and inspect the output.
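A minimal sketch of "one list per edit", assuming records shaped like the ZsRE entries discussed in this thread. The helper `group_locality_per_edit` is hypothetical, and whether `editor.edit` accepts nested lists like this depends on the installed EasyEdit version, so treat it as an illustration of the data layout rather than the library's own loader. How edits without any locality entries should be represented is also left open here; the sketch simply gives them an empty list.

```python
def group_locality_per_edit(edit_data, key="Relation_Specificity"):
    """Collect locality prompts and answers as one list per edit record."""
    prompts, answers = [], []
    for record in edit_data:
        items = record.get("locality", {}).get(key, [])
        prompts.append([item["prompt"] for item in items])
        # Take the first reference answer of each locality item.
        answers.append([item["ground_truth"][0] for item in items])
    return {key: {"prompt": prompts, "ground_truth": answers}}


edit_data = [
    {
        "prompt": "What was the record label of Runaway Sunday?",
        "locality": {
            "Relation_Specificity": [
                {"prompt": "The distribution format of Runaway Sunday is",
                 "ground_truth": ["music streaming"]},
                {"prompt": "Runaway Sunday distribution format",
                 "ground_truth": ["music streaming"]},
            ]
        },
    }
]

locality_inputs = group_locality_per_edit(edit_data)
group = locality_inputs["Relation_Specificity"]
# One outer entry per edit, however many locality prompts that edit has.
print(len(group["prompt"]), len(group["prompt"][0]))
```

With this layout the outer list length equals the number of edits, which is what a per-edit length check compares against.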

SXxinxiaosong commented 2 weeks ago

Great, thank you!

SXxinxiaosong commented 2 weeks ago

Hello, it seems that the number of EasyEdit `prompts` and `locality_prompts` has to match? Could you please take a look? Thanks.

Dataset: zsre-test

Data format:

```
{
  "subject": "Runaway Sunday",
  "target_new": "Motown",
  "prompt": "What was the record label of Runaway Sunday?",
  "ground_truth": [ "Virgin Records" ],
  "rephrase_prompt": "What was Runaway Sunday's record label?",
  "cond": "A&M Records >> Motown || What was the record label of Runaway Sunday?",
  "locality": {
    "Relation_Specificity": [
      { "prompt": "The distribution format of Runaway Sunday is", "ground_truth": [ "music streaming" ] },
      { "prompt": "Runaway Sunday distribution format", "ground_truth": [ "music streaming" ] }
    ]
  },
  "portability": {
    "Reasoning": [
      { "prompt": "Who founded the record label that signed Runaway Sunday?", "ground_truth": "Berry Gordy" }
    ]
  }
},
```

Partial code:

portability_prompts = [
    portability_data['prompt']
    for edit_data_ in edit_data
    for portability_category in edit_data_.get('portability', {}).values()
    for portability_data in portability_category
]
portability_ans = [
    portability_data['ground_truth']
    for edit_data_ in edit_data
    for portability_category in edit_data_.get('portability', {}).values()
    for portability_data in portability_category
]
portability_inputs = {
    'one_hop': {
        'prompt': portability_prompts,
        'ground_truth': portability_ans
    },
}
locality_prompts = [
    relation['prompt']
    for edit_data_ in edit_data if 'locality' in edit_data_
    for relation in edit_data_['locality'].get('Relation_Specificity', [])
]
locality_ans = [
    relation['ground_truth'][0]
    for edit_data_ in edit_data if 'locality' in edit_data_
    for relation in edit_data_['locality'].get('Relation_Specificity', [])
]
locality_inputs = {
    'Relation_Specificity': {
        'prompt': locality_prompts,
        'ground_truth': locality_ans
    },
}
hparams = MENDHyperParams.from_hparams('./hparams/MEND/llama-7b.yaml')
editor = BaseEditor.from_hparams(hparams)
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    portability_inputs=portability_inputs,
    locality_inputs=locality_inputs,
    keep_original_weight=False,
    sequential_edit=True,
)

Error:

Traceback (most recent call last):
  File "/home/xsong/EasyEdit/edit.py", line 2859, in <module>
    main()
  File "/home/xsong/EasyEdit/edit.py", line 2789, in main
    test_MEND_Llama()
  File "/home/xsong/EasyEdit/edit.py", line 783, in test_MEND_Llama
    metrics, edited_model, _ = editor.edit(
  File "/home/xsong/EasyEdit/easyeditor/editors/editor.py", line 157, in edit
    requests = _prepare_requests(prompts, target_new, ground_truth, rephrase_prompts, locality_inputs, portability_inputs, **kwargs)
  File "/home/xsong/EasyEdit/easyeditor/editors/utils.py", line 115, in _prepare_requests
   assert len(locality_inputs[locality_key]['prompt']) == len(locality_inputs[locality_key]['ground_truth']) \
            == len(requests), print('One Edit instance needs one locality input.....')
AssertionError: None
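
The traceback is consistent with the flattened locality list being longer than the list of edits: the sample record alone contributes two `Relation_Specificity` prompts to a single edit. The message reads `AssertionError: None` because the assert's message expression is a `print(...)` call, which prints its text and returns `None`. The sketch below reproduces both effects with plain lists; `check_lengths` is a hypothetical stand-in for the check in `_prepare_requests`, not the library code itself.

```python
def check_lengths(locality_inputs, requests):
    """Stand-in for the length check shown in the traceback (hypothetical helper)."""
    for key, value in locality_inputs.items():
        n_prompts = len(value["prompt"])
        n_answers = len(value["ground_truth"])
        if not (n_prompts == n_answers == len(requests)):
            raise ValueError(
                f"{key}: {n_prompts} prompts, {n_answers} answers, "
                f"{len(requests)} edits"
            )


# One edit, but a flattened locality list with two entries.
requests = [{"prompt": "What was the record label of Runaway Sunday?"}]
flat_locality = {
    "Relation_Specificity": {
        "prompt": [
            "The distribution format of Runaway Sunday is",
            "Runaway Sunday distribution format",
        ],
        "ground_truth": ["music streaming", "music streaming"],
    }
}

try:
    check_lengths(flat_locality, requests)
except ValueError as err:
    message = str(err)
    print(message)

# Why the original message is "None": print() runs, then returns None,
# and that return value becomes the argument of the AssertionError.
try:
    assert False, print("printed, but not attached to the exception")
except AssertionError as err:
    assertion_args = err.args
    print(assertion_args)  # (None,)
```

Raising an exception with a formatted string, as in `check_lengths`, keeps the counts in the error itself, which makes this kind of mismatch easier to spot.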