zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License

Autoregressive generation #404

Closed: nlper-fighting closed this issue 3 weeks ago

nlper-fighting commented 1 month ago

How can I implement autoregressive generation instead of teacher forcing in the inference phase?

pengzju commented 1 month ago

You can see the code here: https://github.com/zjunlp/EasyEdit/blob/main/easyeditor/evaluate/evaluate_utils.py#L277

def verify_answer(model_answer, correct_answer):
    # Normalize a bare string into a list of answer alternatives.
    if isinstance(correct_answer, str):
        correct_answer = [[correct_answer]]
    # Every answer group must have at least one alternative that
    # appears as a substring of the model's output.
    for answer in correct_answer:
        if not any(possible_answer in model_answer for possible_answer in answer):
            return False
    return True

def answer_match(
    model,
    tok,
    prompt: str,
    target_new: str,
    device,
):
    # Greedy autoregressive generation: the model continues the prompt
    # token by token instead of being teacher-forced with the target.
    inputs = tok.encode(prompt, return_tensors='pt').to(device)
    outputs = model.generate(inputs, temperature=0, max_new_tokens=30)
    predict = tok.decode(outputs[0], skip_special_tokens=True)

    return verify_answer(predict, target_new)
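
For reference, a hypothetical invocation might look like this (the model name, prompt, and target are placeholder examples, not from EasyEdit):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')
tok = AutoTokenizer.from_pretrained('gpt2')

ok = answer_match(model, tok,
                  prompt='The capital of France is',
                  target_new='Paris',
                  device='cuda')
print(ok)  # True if "Paris" appears anywhere in the generated text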

Alternatively, you can evaluate with GPT-4 or a string-level exact match.

Pseudo code:

outputs = model.generate(inputs, temperature=0, max_new_tokens=30)
predict = tok.decode(outputs[0], skip_special_tokens=True)
metric = em(predict, target_new)         # string exact match
# or:
metric = gpt4_eval(predict, target_new)  # GPT-4 as a judge
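
Here em and gpt4_eval are placeholders. A minimal sketch of em, assuming "exact match" means the normalized target appears in the normalized prediction (one plausible implementation, not EasyEdit's):

def em(predict: str, target: str) -> float:
    # Lowercase and collapse whitespace before comparing.
    normalize = lambda s: ' '.join(s.lower().split())
    return float(normalize(target) in normalize(predict))
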
nlper-fighting commented 1 month ago

Thanks for your reply~ If I set vanilla_generation to True in the test_prediction_acc function, will it enable autoregressive generation for all models?

def test_prediction_acc(model, tok, hparams, prompts, targets, device, locality=False, vanilla_generation=False):

pengzju commented 1 month ago

When vanilla_generation is set to True, the generated token sequence length matches that of target_new, and every token must match exactly for the metric to be 1. This makes the evaluation very strict.
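
For illustration, that strict check could be sketched like this (a hypothetical helper assuming a Hugging Face model and tokenizer, not EasyEdit's exact code):

def strict_generation_match(model, tok, prompt, target_new, device):
    # Generate exactly as many tokens as target_new occupies,
    # then require an exact token-by-token match.
    target_ids = tok.encode(target_new, add_special_tokens=False)
    inputs = tok(prompt, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, do_sample=False,
                             max_new_tokens=len(target_ids))
    gen_ids = outputs[0][inputs['input_ids'].shape[1]:].tolist()
    return float(gen_ids == target_ids)  # 1.0 only on a perfect match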

A more reasonable approach would be to let the LLM generate a passage and then check if target_new appears within it, or to calculate recall (as per the code above).
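
One way to read "recall" here is the fraction of target_new tokens that show up in a longer generated passage. A sketch under that assumption (hypothetical, not a built-in EasyEdit metric):

def token_recall(model, tok, prompt, target_new, device, max_new_tokens=50):
    # Let the model generate a passage, then count how many target
    # tokens (here: whitespace-split words) appear in it.
    inputs = tok(prompt, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, do_sample=False,
                             max_new_tokens=max_new_tokens)
    generated = tok.decode(outputs[0][inputs['input_ids'].shape[1]:],
                           skip_special_tokens=True)
    target_tokens = target_new.split()
    hits = sum(t in generated for t in target_tokens)
    return hits / len(target_tokens)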

nlper-fighting commented 1 month ago

In this case, the accuracy is either 0 or 1, not a token-by-token accuracy.

pengzju commented 1 month ago

There is no other way.

nlper-fighting commented 1 month ago

If locality is calculated in the same way, can it reach 100%?

pengzju commented 1 month ago

It's OK. Locality is defined as the requirement that the post-edit model should not change its output on irrelevant examples.

$\text{Loc.} = \frac{1}{T}\sum\limits_{t=1}^{T} \mathbb{1}\left(f_{\Theta_T}(x_{\text{loc}}^{t}) = f_{\Theta_0}(x_{\text{loc}}^{t})\right)$
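
In code, this amounts to comparing pre- and post-edit outputs on the locality prompts. A sketch, assuming greedy decoding and two model copies (a hypothetical helper, not EasyEdit's implementation):

import torch

def locality_score(pre_model, post_model, tok, loc_prompts, device,
                   max_new_tokens=30):
    # Fraction of locality prompts where the pre- and post-edit
    # generations agree exactly.
    agree = 0
    for prompt in loc_prompts:
        inputs = tok(prompt, return_tensors='pt').to(device)
        pre = pre_model.generate(**inputs, do_sample=False,
                                 max_new_tokens=max_new_tokens)
        post = post_model.generate(**inputs, do_sample=False,
                                   max_new_tokens=max_new_tokens)
        agree += int(torch.equal(pre[0], post[0]))
    return agree / len(loc_prompts)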

nlper-fighting commented 1 month ago

Thanks for your response! In the autoregressive setting, can the calculation be expressed as the recall of the post-edit model on the loc_prompt divided by the recall of the pre-edit model?

pengzju commented 1 month ago

The extensive past literature compares outputs before and after editing, and I don't understand what you mean by "recall". Moreover, if you wish to change the metrics, there's no need to consult the EasyEdit Team; you can simply use your own discretion.

At least in my understanding, the concept of recall is not equivalent to $\text{Loc.} = \frac{1}{T}\sum\limits_{t=1}^{T} \mathbb{1}\left(f_{\Theta_T}(x_{\text{loc}}^{t}) = f_{\Theta_0}(x_{\text{loc}}^{t})\right)$

It's accuracy (Acc.).

nlper-fighting commented 1 month ago

You mean token-by-token accuracy for locality in autoregressive generation?

pengzju commented 1 month ago

Yes
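
For a per-token variant (averaging token-level agreement between the pre- and post-edit generations instead of scoring all-or-nothing), a hypothetical sketch:

def token_locality(pre_model, post_model, tok, prompt, device,
                   max_new_tokens=30):
    inputs = tok(prompt, return_tensors='pt').to(device)
    n = inputs['input_ids'].shape[1]
    pre = pre_model.generate(**inputs, do_sample=False,
                             max_new_tokens=max_new_tokens)[0][n:]
    post = post_model.generate(**inputs, do_sample=False,
                               max_new_tokens=max_new_tokens)[0][n:]
    m = min(len(pre), len(post))
    # Mean agreement over the overlapping generated tokens.
    return (pre[:m] == post[:m]).float().mean().item()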

zxlzr commented 4 weeks ago

Hi buddy, do you have any further questions?