facebookresearch / GENRE

Autoregressive Entity Retrieval

Threshold for minimum score for reliable results #63

Closed paulthemagno closed 2 years ago

paulthemagno commented 3 years ago

I'm trying to understand what a minimum score threshold for reliable results during inference could be.

Looking at this example in the README:

import pickle
from genre.trie import Trie
from genre.fairseq_model import GENRE

# load the prefix tree (trie)
with open("../data/kilt_titles_trie_dict.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# load the model
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()

# generate Wikipedia titles
model.sample(
    sentences=["Einstein was a [START_ENT] German [END_ENT] physicist."],
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
[[{'text': 'Germany', 'score': tensor(-0.1856)},
  {'text': 'Germans', 'score': tensor(-0.5461)},
  {'text': 'German Empire', 'score': tensor(-2.1858)}]]

All three results can be considered good (German Empire is probably the weakest, but it is still coherent with the searched string).

Running inference on many strings, I noticed that "strange" entities (like invented names of characters from a fairy tale, which shouldn't have any match on the web) still get linked by GENRE, producing (obviously) wrong results. I thought I could use low score values to cut them out and return nothing (as I would like).

Is there a minimum threshold I can use to keep only "pretty sure" results? What is the range of the scores? It seems to be on a logarithmic scale, doesn't it? Is hypo["score"] in this line the average log-probability of the target/reference tokens over the decoder steps?
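
For concreteness, something like this minimal sketch is what I have in mind, on top of the README example above (the -0.5 cutoff is just an arbitrary placeholder I picked for illustration, not a value I have validated):

# hypothetical cutoff on the average log-probability score
MIN_SCORE = -0.5

results = model.sample(
    sentences=["Einstein was a [START_ENT] German [END_ENT] physicist."],
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

# keep only hypotheses whose score clears the cutoff; an empty inner list
# would mean "no reliable entity found" for that sentence
filtered = [
    [hypo for hypo in hypos if hypo["score"].item() >= MIN_SCORE]
    for hypos in results
]
print(filtered)
# with the scores above, only 'Germany' (-0.1856) would survive a -0.5 cutoff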

creatorrr commented 2 years ago

Is there any internal heuristic, etc., that the GENRE team uses for this?

nicola-decao commented 2 years ago

To make sure the output is reasonable, you should use constrained beam search, which you are already doing. The score is not a raw log-probability but the average log-probability, i.e., the sum of the token log-probabilities divided by the number of tokens, to discourage short predictions. This is exactly the same as in machine translation. A minimum threshold can be used, but I did not run any experiments with it. Besides, we used GENRE for classification, so we took the top-1 prediction as the class regardless of the score.
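
In other words, since the score is length-normalized, exp(score) can be read as a geometric-mean per-token probability, which may make it easier to reason about a cutoff. With the numbers from the example above (purely illustrative):

import math

# converting the average log-probability scores from the README example
# into per-token probabilities
math.exp(-0.1856)  # ~0.83 -> 'Germany'
math.exp(-0.5461)  # ~0.58 -> 'Germans'
math.exp(-2.1858)  # ~0.11 -> 'German Empire'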