facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Forced decoding & decoder score #4669

Open zouharvi opened 2 years ago

zouharvi commented 2 years ago

❓ Forced decoding & decoder score

I'm using the hub interface. It is possible to get the decoder score of a hypothesis that the model has just generated:

model = torch.hub.load(...)
sent_src_enc = model.encode(sent_src)
sent_tgt_enc = model.generate(sent_src_enc, nbest=1)[0]
sent_tgt_score = sent_tgt_enc["score"].item()

Assuming that I already have the source and hypothesis text from some other source, how would I force the decoder to decode the given target text and return its log-probability? I know of the existence of SequenceScorer and --score-reference, but I was unable to use them with the hub interface:

scorer = SequenceScorer(model.tgt_dict)
sent_src_enc = model.encode(sent_src)
sent_tgt_enc = model.encode(sent_tgt)
scorer.generate(model.models(), {"net_input": sent_src_enc, "target": sent_tgt_enc}) # ERROR


gmryu commented 2 years ago

What is the error? Is it caused by a wrong structure of the input, i.e. {"net_input": sent_src_enc, "target": sent_tgt_enc}?

If it is a transformer trained on a normal translation task, it uses LanguagePairDataset. From that dataset's definition you can see the expected batch structure: it creates prev_output_tokens from target.

In addition, the transformer's forward signature shows that the decoder does not look up target but prev_output_tokens. So I guess you should change the key target to prev_output_tokens.

One last thing: prev_output_tokens is already a tensor, not a dictionary (see the decoder's forward). You had better check what model.encode(sent_tgt) returns and take what you need from it. The merge helper inside the dataset's collater may help; a sketch of the shift it performs is given below.
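
To make that concrete, here is a minimal sketch of the shift the collater applies (the helper name make_prev_output_tokens is made up for illustration), assuming target is a 1-D token tensor that ends in EOS, as model.encode() returns:

import torch

def make_prev_output_tokens(target: torch.LongTensor, eos: int) -> torch.LongTensor:
    # "move EOS to the beginning": the decoder input at step t is the gold
    # target token from step t-1, with EOS serving as the start symbol
    prev = target.clone()
    prev[0] = eos
    prev[1:] = target[:-1]
    return prev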

If your model is not a fairseq transformer, or you are using a different dataset, the basic flow is the same: find out what data is expected and build a proper batch yourself. Good luck!


You can find a somewhat related example in eval_lm inside fairseq_cli/eval_lm.py. You may call that method if a cross-entropy loss is what you ultimately want.
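
For reference, the command-line wrapper around that code path is fairseq-eval-lm; a typical invocation, following the language-model example in the fairseq docs (paths and sizes here are placeholders), looks like:

fairseq-eval-lm data-bin/wikitext-103 \
    --path checkpoints/checkpoint_best.pt \
    --batch-size 2 \
    --tokens-per-sample 512 \
    --context-window 400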

erip commented 1 year ago

You can perform forced decoding with the following script:

#!/usr/bin/env python3

import torch
from fairseq.sequence_scorer import SequenceScorer
from fairseq.models.transformer import TransformerModel

if __name__ == "__main__":
    sent_src = "Hello world!"
    sent_tgt = "Hallo Welt!"

    model = TransformerModel.from_pretrained(...)
    scorer = SequenceScorer(model.tgt_dict)

    enc_src = model.encode(sent_src)
    ref_enc = model.encode(sent_tgt)
    # teacher forcing: the decoder input is the reference shifted right by one,
    # with EOS moved to the front as the start symbol (model.encode() output
    # already ends in EOS, so drop its last token)
    prev = torch.cat([ref_enc.new_full((1,), model.tgt_dict.eos()), ref_enc[:-1]]).unsqueeze(0)

    sample = {
        "net_input": {
            "src_tokens": enc_src.unsqueeze(0),
            "src_lengths": torch.LongTensor([enc_src.numel()]),
            "prev_output_tokens": prev,
        },
        "target": ref_enc.unsqueeze(0),
    }
    score = scorer.generate(model.models, sample)
    # average per-token natural-log probability of the reference
    print(score[0][0]["score"])
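
If you need per-token scores rather than the sentence-level average, the hypothesis dict returned by SequenceScorer should also carry them under positional_scores (one log-probability per target token, including the final EOS):

hypo = score[0][0]
print(hypo["positional_scores"])         # per-token log-probabilities
print(hypo["positional_scores"].mean())  # equals hypo["score"]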