Description of changes:
Currently, GetModelResponse is a transform that augments the input record with the entire model response payload from calling the predict method of a ModelRunner. The response payload is of the form (model_output, log_prob), where both model_output and log_prob are optional, and whether they are non-null depends on the particular ModelRunner being used.
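For illustration, the payload shape described above can be sketched with two hypothetical runner classes. OutputOnlyRunner and OutputAndLogProbRunner are invented for this example and are not part of the library. Both return a two-element (model_output, log_prob) tuple, with None standing in for a value the runner does not produce.

```python
from typing import Optional, Tuple


class OutputOnlyRunner:
    """Hypothetical runner that produces text but no log probability."""

    def predict(self, prompt: str) -> Tuple[Optional[str], Optional[float]]:
        # The log probability slot is present but empty.
        return f"summary of: {prompt}", None


class OutputAndLogProbRunner:
    """Hypothetical runner that fills both elements of the payload."""

    def predict(self, prompt: str) -> Tuple[Optional[str], Optional[float]]:
        # Both the model output and a log probability are returned.
        return f"summary of: {prompt}", -1.25
```

Which slots are non-null depends entirely on the runner. A transform that consumes this payload therefore cannot assume that a log probability is present or absent.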
When initializing a GetModelResponse, one must provide a tuple of keys under which the elements of the response payload will be stored. Algorithms that currently use GetModelResponse, such as SummarizationAccuracy, only care about model outputs (not log probabilities), so they provide only a key for the model output when constructing GetModelResponse instances.
The bug is that GetModelResponse currently requires the response payload to conform to the format of the response key tuple, instead of the other way around. For example, if a GetModelResponse is initialized with the response key tuple (model_output_key, ), its __call__ method raises an error whenever the model's predict method returns a payload that also contains a log probability, because of this check:
assert_condition(
len(model_response) == len(response_key_tuple),
f"The number of elements in model response {model_response} "
f"does not match number of response keys in {response_key_tuple}.",
)
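A minimal reproduction of the failure follows. It uses a simplified stand-in for the pre-fix logic. The name old_attach_response is hypothetical, and the local assert_condition only approximates the real helper by raising ValueError. In the sketch, a caller passes a one-element key tuple, the runner returns both an output and a log probability, and the length check rejects the payload.

```python
from typing import Any, Dict, Tuple


def assert_condition(condition: bool, message: str) -> None:
    # Hypothetical stand-in for the library helper; raises ValueError as a placeholder.
    if not condition:
        raise ValueError(message)


def old_attach_response(
    record: Dict[str, Any],
    model_response: Tuple[Any, ...],
    response_key_tuple: Tuple[str, ...],
) -> Dict[str, Any]:
    # Pre-fix logic: the payload must have exactly one element per response key.
    assert_condition(
        len(model_response) == len(response_key_tuple),
        f"The number of elements in model response {model_response} "
        f"does not match number of response keys in {response_key_tuple}.",
    )
    # Store each payload element under its corresponding key.
    for key, value in zip(response_key_tuple, model_response):
        record[key] = value
    return record


# The caller only wants the model output, so it passes a single key ...
response_keys = ("model_output",)
# ... but the runner returned both a model output and a log probability.
payload = ("a short summary", -1.25)

try:
    old_attach_response({"model_input": "some text"}, payload, response_keys)
except ValueError as err:
    print(f"Rejected: {err}")
```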
This PR renames GetModelResponse to GetModelOutputs and fixes the bug above. GetModelOutputs is now responsible solely for extracting the model_output portion of the (model_output, log_prob) response returned by predict. More importantly, its __call__ logic no longer makes any assumptions about the format of the response payload.
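The sketch below shows the fixed behavior under stated assumptions. GetModelOutputsSketch, its constructor parameters, and FakeRunner are hypothetical and do not mirror the real GetModelOutputs signature. The sketch only shows that __call__ reads the first element of the predict response and ignores everything else. As a result, a runner that also returns a log probability no longer triggers an error.

```python
from typing import Any, Dict, Optional, Protocol, Tuple


class ModelRunner(Protocol):
    # Minimal interface assumed for this sketch: predict returns (model_output, log_prob).
    def predict(self, prompt: str) -> Tuple[Optional[str], Optional[float]]: ...


class GetModelOutputsSketch:
    """Hypothetical sketch of the fixed transform; not the library's actual class."""

    def __init__(self, input_key: str, output_key: str, model_runner: ModelRunner):
        self.input_key = input_key
        self.output_key = output_key
        self.model_runner = model_runner

    def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
        # Keep only the model output; any log probability in the payload is ignored.
        model_output = self.model_runner.predict(record[self.input_key])[0]
        record[self.output_key] = model_output
        return record


class FakeRunner:
    # Hypothetical runner that returns both elements of the payload.
    def predict(self, prompt: str) -> Tuple[Optional[str], Optional[float]]:
        return prompt.upper(), -0.5


transform = GetModelOutputsSketch("model_input", "model_output", FakeRunner())
result = transform({"model_input": "hello"})
```

The sketch indexes with [0] instead of unpacking into two names. This ties the transform only to the position of model_output, not to the total length of the payload. That length was the assumption that broke the old check.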
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.