SXxinxiaosong closed this issue 2 weeks ago.
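Hi, yes. Because of recent EasyEdit updates, the results of some methods have improved. We are updating the numbers and will post the latest results soon.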
I have two more questions I'd like to ask~:
1. Currently the average is computed over all of them.
2. Table 4 uses ZsRE-test-all.json.
Got it, thanks!!
```python
assert len(locality_inputs[locality_key]['prompt']) == len(locality_inputs[locality_key]['ground_truth']) \
    == len(requests), print('One Edit instance needs one locality input.....')
```
But doesn't this mean each edit can have only one locality input?
Hi, in the KnowEdit data processing, each locality entry is stored as a list. You can debug it yourself and inspect the output.
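Following that reply, here is a minimal sketch of building `locality_inputs` with one entry per edit, where each entry is the list of that edit's locality items, so the outer list lines up with the edit requests. `build_locality_inputs` is a hypothetical helper, not an EasyEdit function. It assumes, as the reply describes, that `locality_inputs[...]['prompt']` may hold a list per edit; giving an edit with no locality items an empty list is a further assumption worth checking against `_prepare_requests`.

```python
from typing import Any, Dict, List


def build_locality_inputs(edit_data: List[Dict[str, Any]],
                          key: str = 'Relation_Specificity') -> Dict[str, Dict[str, List[List[str]]]]:
    # Hypothetical helper (not part of EasyEdit): one entry per edit, each a list of
    # that edit's locality prompts/answers, so the outer lists line up with the edits.
    prompts: List[List[str]] = []
    answers: List[List[str]] = []
    for record in edit_data:
        items = record.get('locality', {}).get(key, [])
        prompts.append([item['prompt'] for item in items])
        # ZsRE stores ground_truth as a list; keep its first element
        answers.append([item['ground_truth'][0] for item in items])
        # Assumption: an edit with no locality items ends up with an empty list here
    return {key: {'prompt': prompts, 'ground_truth': answers}}


sample = [{
    'prompt': 'What was the record label of Runaway Sunday?',
    'locality': {'Relation_Specificity': [
        {'prompt': 'The distribution format of Runaway Sunday is', 'ground_truth': ['music streaming']},
        {'prompt': 'Runaway Sunday distribution format', 'ground_truth': ['music streaming']},
    ]},
}]
locality_inputs = build_locality_inputs(sample)
print(len(locality_inputs['Relation_Specificity']['prompt']))  # 1, same as len(sample)
```

Grouping per edit keeps the outer length equal to the number of edits no matter how many locality items each record carries.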
Got it! Thanks~
Hello, it seems that EasyEdit's `prompts` and `locality_prompts` need to have the same length? Could you please take a look? Thanks~
Dataset: zsre-test
Data format:
```json
{
  "subject": "Runaway Sunday",
  "target_new": "Motown",
  "prompt": "What was the record label of Runaway Sunday?",
  "ground_truth": ["Virgin Records"],
  "rephrase_prompt": "What was Runaway Sunday's record label?",
  "cond": "A&M Records >> Motown || What was the record label of Runaway Sunday?",
  "locality": {
    "Relation_Specificity": [
      {"prompt": "The distribution format of Runaway Sunday is", "ground_truth": ["music streaming"]},
      {"prompt": "Runaway Sunday distribution format", "ground_truth": ["music streaming"]}
    ]
  },
  "portability": {
    "Reasoning": [
      {"prompt": "Who founded the record label that signed Runaway Sunday?", "ground_truth": "Berry Gordy"}
    ]
  }
}
```
Part of my code:
```python
from easyeditor import BaseEditor, MENDHyperParams

# `edit_data` is the loaded zsre-test JSON; `prompts`, `ground_truth`, `target_new` are built from it elsewhere.
portability_prompts = [portability_data['prompt']
                       for edit_data_ in edit_data
                       for portability_category in edit_data_.get('portability', {}).values()
                       for portability_data in portability_category]
portability_ans = [portability_data['ground_truth']
                   for edit_data_ in edit_data
                   for portability_category in edit_data_.get('portability', {}).values()
                   for portability_data in portability_category]
portability_inputs = {
    'one_hop': {
        'prompt': portability_prompts,
        'ground_truth': portability_ans
    },
}
# Note: this flattens every Relation_Specificity item of every edit into one list,
# so an edit with two locality items contributes two entries.
locality_prompts = [relation['prompt']
                    for edit_data_ in edit_data if 'locality' in edit_data_
                    for relation in edit_data_['locality'].get('Relation_Specificity', [])]
locality_ans = [relation['ground_truth'][0]
                for edit_data_ in edit_data if 'locality' in edit_data_
                for relation in edit_data_['locality'].get('Relation_Specificity', [])]
locality_inputs = {
    'Relation_Specificity': {
        'prompt': locality_prompts,
        'ground_truth': locality_ans
    },
}
hparams = MENDHyperParams.from_hparams('./hparams/MEND/llama-7b.yaml')
editor = BaseEditor.from_hparams(hparams)
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    portability_inputs=portability_inputs,
    locality_inputs=locality_inputs,
    keep_original_weight=False,
    sequential_edit=True,
)
```
Error:
```
Traceback (most recent call last):
  File "/home/xsong/EasyEdit/edit.py", line 2859, in <module>
    main()
  File "/home/xsong/EasyEdit/edit.py", line 2789, in main
    test_MEND_Llama()
  File "/home/xsong/EasyEdit/edit.py", line 783, in test_MEND_Llama
    metrics, edited_model, _ = editor.edit(
  File "/home/xsong/EasyEdit/easyeditor/editors/editor.py", line 157, in edit
    requests = _prepare_requests(prompts, target_new, ground_truth, rephrase_prompts, locality_inputs, portability_inputs, **kwargs)
  File "/home/xsong/EasyEdit/easyeditor/editors/utils.py", line 115, in _prepare_requests
    assert len(locality_inputs[locality_key]['prompt']) == len(locality_inputs[locality_key]['ground_truth']) \
        == len(requests), print('One Edit instance needs one locality input.....')
AssertionError: None
```
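The last line reads `AssertionError: None` because the assertion uses `print(...)` as its message: `print` writes the text and returns `None`, and that `None` becomes the exception's argument. The check itself fails because, for the sample record, one edit contributes two flattened Relation_Specificity items. Below is a standalone plain-Python reproduction (no EasyEdit needed); `run_check` and the plain-string `requests` list are hypothetical stand-ins for the real request dicts built inside `_prepare_requests`.

```python
import contextlib
import io


def run_check(prompts, answers, requests):
    # Hypothetical wrapper around the same pattern as the quoted EasyEdit assertion.
    # Returns (args of the AssertionError, or None if the check passed; captured stdout).
    # Note: assert statements are skipped entirely under `python -O`.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        try:
            # The message expression is print(...): it prints, then evaluates to None,
            # which becomes the exception's argument ("AssertionError: None").
            assert len(prompts) == len(answers) == len(requests), \
                print('One Edit instance needs one locality input.....')
        except AssertionError as err:
            return err.args, buf.getvalue()
    return None, buf.getvalue()


# Stand-in for the requests list: one edit built from the sample record.
requests = ['What was the record label of Runaway Sunday?']
flat_prompts = ['The distribution format of Runaway Sunday is',
                'Runaway Sunday distribution format']
flat_answers = ['music streaming', 'music streaming']

print(run_check(flat_prompts, flat_answers, requests))      # flattened: 2 vs 1, fails
print(run_check([flat_prompts], [flat_answers], requests))  # one list per edit: passes
```

Wrapping the two locality items of the single edit in one list makes every outer length equal to 1, so the same check passes without printing anything.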
Hello, which data are used to evaluate locality and portability in Table 4? Is it the data that ships with the test set? For example, is ZsRE locality evaluated with the `locality` field in ZsRE-test-all.json?