zouharvi opened 2 years ago
What is the error? Is it a wrong input structure, e.g. `{"net_input": sent_src_enc, "target": sent_tgt_enc}`?
If it is a transformer with a normal translation task, it uses `LanguagePairDataset`. From its definition in the .py file, you can see its structure: it creates `prev_output_tokens` from `target`. In addition, the transformer's `forward` shows that it does not look up `target` but `prev_output_tokens`. So I guess you should change the key `target` to `prev_output_tokens`.
One last thing: `prev_output_tokens` is already a tensor, not a dictionary (see the decoder's `forward`).
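As a quick illustration of the shift that the collate step performs when building `prev_output_tokens` from `target` (the token ids below are made up for the example; only the eos index matters):

```python
eos = 2  # fairseq's usual eos index; the other ids are made up
target = [10, 11, 12, eos]

# prev_output_tokens is the target rotated right by one position,
# so eos moves to the front and the final eos is dropped
prev_output_tokens = [eos] + target[:-1]
print(prev_output_tokens)  # → [2, 10, 11, 12]
```

fairseq does the same rotation on `LongTensor`s inside the dataset's collater; this sketch only shows the index arithmetic.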
You had better check what `model.encode(sent_tgt)` is and take what you need from it. You may use the `merge` helper if it helps.
If your model is not a fairseq transformer, or you are using a different dataset, the basic flow is the same: find out what data is expected and build a proper batch yourself. Good luck!
You can find a somewhat related example in `fairseq_cli/eval_lm.py`'s `eval_lm`. You may call that method if the loss (entropy) is what you want in the end.
You can perform forced decoding with the following script (note: `prev_output_tokens` must be the eos-shifted target, not a tensor of eos tokens):

```python
#!/usr/bin/env python3
import torch

from fairseq.models.transformer import TransformerModel
from fairseq.sequence_scorer import SequenceScorer

if __name__ == "__main__":
    sent_src = "Hello world!"
    sent_tgt = "Hallo Welt!"

    model = TransformerModel.from_pretrained(...)
    scorer = SequenceScorer(model.tgt_dict)

    enc_src = model.encode(sent_src)  # 1-D LongTensor of source token ids
    ref_enc = model.encode(sent_tgt)  # 1-D LongTensor of target ids, ends with eos

    # prev_output_tokens = target shifted right by one, with eos at the front
    prev = torch.cat(
        [ref_enc.new_full((1,), model.tgt_dict.eos()), ref_enc[:-1]]
    ).unsqueeze(0)

    sample = {
        "net_input": {
            "src_tokens": enc_src.unsqueeze(0),
            "src_lengths": torch.LongTensor([enc_src.numel()]),
            "prev_output_tokens": prev,
        },
        "target": ref_enc.unsqueeze(0),
    }

    score = scorer.generate(model.models, sample)
    # average log_e probability of the forced target
    print(score[0][0]["score"])
```
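The hypo dicts returned by `SequenceScorer` also carry per-token log-probabilities under the `"positional_scores"` key. If you want a perplexity rather than an average log-prob, the conversion is straightforward (the per-token scores below are made up; real ones come from the hypo dict):

```python
import math

# made-up per-token log_e probabilities, as found in hypo["positional_scores"]
positional_scores = [-0.5, -1.2, -0.3]

# perplexity = exp of the negative mean log-probability
avg_logprob = sum(positional_scores) / len(positional_scores)
ppl = math.exp(-avg_logprob)
print(ppl)
```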
❓ Forced decoding & decoder score
I'm using the hub interface. Is it possible to get the decoder scores of a hypothesis just generated by the model?
Assuming that I already have the source and hypothesis text from some other source, how would I force the decoder to decode the target text and return the log-probability? I know of the existence of `SequenceScorer` and `--score-reference`, but was unable to use them with the hub interface.