ServiceNow / picard

PICARD - Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models. PICARD is a ServiceNow Research project that was started at Element AI.
https://arxiv.org/abs/2109.05093
Apache License 2.0

Retrieve Prediction Probabilities From Output #102

Closed: adamkhakhar closed this issue 2 years ago.

adamkhakhar commented 2 years ago

Hi, how can I retrieve the probability values associated with a prediction from the model? For example, when predicting a query, I would like to get the probability of each token the model selects during beam search. I tried Hugging Face's built-in options, but could not get them to work. Adding the argument output_scores=True at line 134 of seq2seq/serve_seq2seq.py does not work on its own. When I add both output_scores=True and return_dict_in_generate=True, it throws the following error:

```
  File "/app/seq2seq/utils/pipeline.py", line 78, in __call__
    result = super().__call__(inputs, *args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/pipelines/text2text_generation.py", line 138, in __call__
    result = super().__call__(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1027, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/opt/conda/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1034, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/opt/conda/lib/python3.7/site-packages/transformers/pipelines/base.py", line 944, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/opt/conda/lib/python3.7/site-packages/transformers/pipelines/text2text_generation.py", line 161, in _forward
    out_b = output_ids.shape[0]
AttributeError: 'BeamSearchEncoderDecoderOutput' object has no attribute 'shape'
```
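The last frame of the traceback points at the cause: with return_dict_in_generate=True, generate() returns an output object (here a BeamSearchEncoderDecoderOutput) that stores the token ids in `.sequences` and the per-step scores in `.scores`, while the pipeline's `_forward` assumes a bare tensor and calls `.shape` on it. Below is a minimal sketch of the unwrapping step, assuming you handle generate()'s return value yourself instead of passing it through the pipeline unchanged. `unwrap_generate_output` and `FakeOutput` are hypothetical names; the stand-in class only lets the sketch run without loading a model.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class FakeOutput:
    """Hypothetical stand-in for a transformers ModelOutput such as
    BeamSearchEncoderDecoderOutput; only the two fields used below are modeled."""
    sequences: Any
    scores: Optional[Tuple[Any, ...]] = None


def unwrap_generate_output(out):
    """Return (token_ids, step_scores) whether generate() returned a bare
    id tensor (return_dict_in_generate=False) or an output object (True)."""
    if hasattr(out, "sequences"):
        # Dict-style output: ids live in .sequences; .scores is only
        # populated when output_scores=True was also passed.
        return out.sequences, getattr(out, "scores", None)
    # Default behaviour: generate() returned the id tensor directly.
    return out, None


# Plain lists stand in for tensors so the sketch runs without torch.
ids, scores = unwrap_generate_output(
    FakeOutput(sequences=[[0, 5, 7, 2]], scores=("step1", "step2", "step3"))
)
print(ids, len(scores))
```

Any shape-based logic (such as `out_b = output_ids.shape[0]`) would then operate on the unwrapped ids, while the scores are kept aside for computing probabilities.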

tscholak commented 2 years ago

Hi @adamkhakhar, please ask this question upstream on the Hugging Face transformers issue tracker.
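For the original goal (the probability of each generated token), the arithmetic can be sketched as follows. This assumes greedy decoding with no logits processors, so that each entry of the `scores` tuple holds the raw logits for the step at which the corresponding token was chosen. For encoder-decoder models those tokens are typically `sequences[1:]`, after the decoder start token. Under beam search the scores are per-beam log-probabilities that must be re-aligned with the beam each token came from. Newer transformers releases provide `compute_transition_scores` for this, but it may not exist in the older version shown in the traceback. For a single sequence-level number, BeamSearchEncoderDecoderOutput also carries `sequences_scores`, the final length-penalized beam scores. `token_probabilities` is a hypothetical helper written in pure Python, with lists of floats standing in for tensors.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_probabilities(step_logits: Sequence[Sequence[float]],
                        generated_ids: Sequence[int]) -> List[float]:
    """Probability of each chosen token, one entry per decoding step.

    step_logits[t] is the logit vector produced at step t for one sequence
    (assumed greedy decoding, no logits processors); generated_ids[t] is the
    token picked at that step, excluding any leading decoder start token.
    """
    if len(step_logits) != len(generated_ids):
        raise ValueError("need exactly one logit vector per generated token")
    return [math.exp(log_softmax(logits)[tok])
            for logits, tok in zip(step_logits, generated_ids)]


# Toy 4-token vocabulary, two decoding steps.
logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
probs = token_probabilities(logits, [0, 3])
print([round(p, 3) for p in probs])  # [0.711, 0.25]
```

Working in log space and exponentiating only at the end avoids overflow for large logits. The probability of the whole sequence is the product of these values, which is best accumulated as a sum of logs.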