aws / fmeval

Foundation Model Evaluations Library
http://aws.github.io/fmeval
Apache License 2.0
187 stars 42 forks

feat: make sm/br runners easier to subclass #159

Closed franluca closed 9 months ago

franluca commented 9 months ago

Issue: when subclassing the JumpStart/SageMaker/Bedrock model runners, one also has to override the __reduce__ method to change the returned class, which is not intuitive.

Description of changes: changed the __reduce__ methods to return self.__class__ rather than a hardcoded class (e.g. BedrockModelRunner), making an override of __reduce__ optional (and very likely unnecessary).
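To illustrate the difference, here is a minimal sketch using plain pickle and stand-in classes. HardcodedRunner, GenericRunner and their subclasses are hypothetical names made up for this example, not fmeval classes, and the sketch assumes Ray's serializer honors __reduce__ the same way pickle does.

```python
import pickle


class HardcodedRunner:
    """Stand-in for a runner whose __reduce__ names its own class explicitly."""

    def __init__(self, model_id: str):
        self._model_id = model_id

    def __reduce__(self):
        # Always rebuilds a HardcodedRunner, even when called on a subclass.
        return HardcodedRunner, (self._model_id,)


class GenericRunner:
    """Same runner, but __reduce__ returns the runtime class."""

    def __init__(self, model_id: str):
        self._model_id = model_id

    def __reduce__(self):
        # self.__class__ resolves to the subclass when called on a subclass.
        return self.__class__, (self._model_id,)


class HardcodedChild(HardcodedRunner):
    pass


class GenericChild(GenericRunner):
    pass


def round_trip(obj):
    """Serialize and deserialize obj, roughly as a worker would receive it."""
    return pickle.loads(pickle.dumps(obj))


lost = round_trip(HardcodedChild("some-model-id"))
kept = round_trip(GenericChild("some-model-id"))
print(type(lost).__name__)  # HardcodedRunner
print(type(kept).__name__)  # GenericChild
```

With the hardcoded class, the subclass silently comes back as its parent, so any overridden method is lost after deserialization unless __reduce__ is overridden too.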

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

franluca commented 9 months ago

Yes correct. Here's an example:

from typing import Optional, Tuple

from fmeval.model_runners.bedrock_model_runner import BedrockModelRunner


class ClaudeModelRunner(BedrockModelRunner):

    def predict(self, prompt: str) -> Tuple[Optional[str], Optional[float]]:
        """
        Overridden to globally follow the Claude "Human: [...] Assistant:" convention.
        TODO: is there a way to avoid doing this and just use the BedrockModelRunner from the library?
        """
        # Wrap the raw prompt in the Human/Assistant turns before delegating.
        prompt = f"Human: {prompt}\n\n Assistant:"
        return super().predict(prompt)

In fact, an even more general implementation of __reduce__ could be the following:

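    def __reduce__(self):
        """
        Custom serializer method used by Ray when it serializes instances of this
        class in eval_algorithms.util.generate_model_predict_response_for_dataset.
        """
        # Previous, explicit version kept for comparison:
        # serialized_data = (
        #     self._model_id,
        #     self._content_template,
        #     self._output,
        #     self._log_probability,
        #     self._content_type,
        #     self._accept_type,
        # )
        # Copy the state first: the dict returned by the default __reduce__ may be
        # the instance's own __dict__, and popping from it would strip the live object.
        serialized_data = dict(super().__reduce__()[2])
        serialized_data.pop('_extractor', None)
        serialized_data.pop('_composer', None)
        serialized_data.pop('_bedrock_runtime_client', None)
        # Relies on the attributes being assigned in constructor-argument order.
        return self.__class__, tuple(serialized_data.values())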

since you basically do not want to serialize these three objects, right?
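For what it's worth, the same idea can be tried outside the library with a small stand-in. This is only a sketch under assumptions: StubRunner, StubChild and _TRANSIENT are hypothetical names, a threading.Lock stands in for a client object that cannot be pickled, and the state is read with vars(self) instead of going through super().__reduce__().

```python
import pickle
import threading


class StubRunner:
    """Hypothetical stand-in for a model runner; not an fmeval class."""

    # Attributes rebuilt by __init__, so they are left out of the pickled state.
    _TRANSIENT = ("_client",)

    def __init__(self, model_id: str, content_template: str):
        # Assignment order matches the constructor's parameter order, which
        # the positional tuple built in __reduce__ depends on.
        self._model_id = model_id
        self._content_template = content_template
        # A lock stands in for a client object that cannot be pickled.
        self._client = threading.Lock()

    def __reduce__(self):
        # Build a filtered copy so the live instance keeps its attributes.
        state = {k: v for k, v in vars(self).items() if k not in self._TRANSIENT}
        return self.__class__, tuple(state.values())


class StubChild(StubRunner):
    pass


original = StubChild("some-model-id", "$prompt")
restored = pickle.loads(pickle.dumps(original))
print(type(restored).__name__)  # StubChild
```

The restored object gets a fresh _client from __init__, and the original still has its own. The trade-off is that the positional tuple only works while the attributes are assigned in the same order as the constructor arguments.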