zhuzihan728 opened 5 months ago
changed "output" to "answers" and it kind of fixed the problem but 0it [00:00, ?it/s]
...
Processed prompts: 0%| | 0/10 [00:00<?, ?it/s][A
Processed prompts: 10%|█ | 1/10 [00:00<00:03, 2.59it/s][A
Processed prompts: 70%|███████ | 7/10 [00:00<00:00, 13.11it/s][A Processed prompts: 100%|██████████| 10/10 [00:00<00:00, 15.28it/s]
19it [00:53, 2.81s/it]
Traceback (most recent call last):
File "run_short_form.py", line 378, in
I think the code isn't complete yet...
preds = []
prompts = []
golds = []
metric_results = []
scores = []
all_results = []
count = 0
for i, row in tqdm(enumerate(input_data)):
    results = {}
    prompt = PROMPT_DICT["prompt_no_input"].format_map(row)
    _, evidences = process_data_evidences(row, top_n=args.ndocs)
    pred, results, do_retrieve = generate(
        prompt, evidences, max_new_tokens=args.max_new_tokens,)
    if type(pred) is str and pred[0] == "#" or pred[0] == ":":
        pred = pred[1:]
    prompts.append(prompt)
    preds.append(pred)
    all_results.append(results)
    if do_retrieve is True:
        count += 1
    # Normalize a single "answer" field into an "answers" list.
    if "answers" not in row and "answer" in row:
        row["answers"] = [row["answer"]] if type(
            row["answer"]) is str else row["answer"]
    ######################################################################
    # 2024-05-22 fixed index out of range error: drop empty gold answers
    row["answers"] = [answer for answer in row["answers"] if answer != ""]
    ######################################################################
    if args.metric == "accuracy":
        ##############################################################################
        # 2024-05-22 fixed KeyError: "output" doesn't exist in the short-form data
        # metric_result = accuracy(pred, row["output"])
        metric_result = accuracy(pred, row["answers"])
        ##############################################################################
    elif args.metric == "match":
        if "SUPPORTS" in pred:
            pred = "true"
        elif "REFUTES" in pred:
            pred = "false"
        metric_result = match(pred, row["answers"])
    else:
        raise NotImplementedError
I modified the code as above and it works.
Thx for the reply :D I see now that the accuracy calculation checks whether the first letter of the prediction matches the first letter of a gold answer, since it zips a string (the prediction) with a list (the gold answers). That only makes sense if the gold answer list has length one, or if the metric is meant for multiple-choice datasets?
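To illustrate the zip behavior (a standalone snippet, not the repo's code):

pred = "true"          # prediction string
answers = ["true"]     # gold answer list of length one

# zip pairs characters of the string with elements of the list,
# so only the first character of the prediction ever gets compared here.
print(list(zip(pred, answers)))        # [('t', 'true')]
print(pred[0] == answers[0][0])        # True: first letter vs. first letter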
And regarding the list index out of range error you pointed out: the author seems to have mistakenly put empty strings ("") in some of the gold answer lists in the eval_data provided at this link. Not sure why, but this definitely doesn't look right, and it only inflates the final metric score when using the match method in metrics.py for the metric calculation (presumably because an empty string trivially matches any prediction).
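A quick self-contained check of both symptoms, assuming the accuracy path indexes the first character of each gold answer and the match path tests whether a gold answer appears in the prediction (an illustration, not the repository's actual code):

pred = "I don't know."
answers = ["Stockholm", ""]    # empty gold answer, as found in some eval_data files

# Indexing the first character of "" raises IndexError,
# which matches the "index out of range" crash fixed above.
try:
    _ = [ans[0] for ans in answers]
except IndexError as err:
    print("IndexError:", err)

# An empty string is a substring of every prediction, so a containment-style
# match would always count it as correct and inflate the score.
print(any(ans in pred for ans in answers))   # True, only because "" matches everything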
@AkariAsai
The accuracy function in metrics.py is defined as:
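(The code block did not survive extraction. Based on the behavior discussed above, zipping predictions with gold answers and comparing each prediction against the first element of its label, it is presumably along these lines; a sketch, not necessarily the exact source.)

def accuracy(preds, labels):
    match_count = 0
    for pred, label in zip(preds, labels):
        # compare the prediction against the first element of the label
        target = label[0]
        if pred == target:
            match_count += 1
    return 100 * (match_count / len(preds))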
While in run_short_form.py, acc is calculated per data instance:
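(Again the original block is missing here; the call in question is the one quoted in the loop above:)

if args.metric == "accuracy":
    metric_result = accuracy(pred, row["output"])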
where pred is some string, and row["output"] is neither present in any short-form dataset, nor defined in your code.