replicate / cog

Containers for machine learning
https://cog.run
Apache License 2.0

Using the Predictor directly results in FieldInfo errors #1790

Open Clement-Lelievre opened 2 months ago

Clement-Lelievre commented 2 months ago

cog==0.9.6 pydantic==1.10.17 python==3.10.0

Hello,

When writing non-end-to-end tests for my cog predictor, in particular tests that instantiate the Predictor and call its predict method directly, I ran into errors like TypeError: '<' not supported between instances of 'FieldInfo' and 'FieldInfo'. This happens because, unlike running the cog CLI or making an HTTP request to a running cog container, calling Predictor.predict directly bypasses cog's argument-processing layer. Any argument that isn't passed explicitly therefore keeps its default, which is a pydantic FieldInfo object (i.e. the object returned by cog.Input), and virtually any basic operation on it will fail.
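
To make the failure mode concrete, here is a stdlib-only sketch that reproduces it without cog installed. `FieldInfoStandIn` is a hypothetical stand-in for pydantic's `FieldInfo` (the object `cog.Input(...)` returns). Like the real class in the error above, it defines no ordering, so comparing it with anything raises `TypeError`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FieldInfoStandIn:
    """Hypothetical stand-in for the pydantic FieldInfo object that cog.Input returns."""

    default: Any


def predict(steps: int = FieldInfoStandIn(default=10)) -> int:
    # Outside cog, nothing replaces the marker object with its .default value,
    # so `steps` is still a FieldInfoStandIn here when the caller omits it.
    return min(steps, 50)


print(predict(steps=5))  # explicit argument: works, prints 5
try:
    predict()  # default used: min() compares an int with a FieldInfoStandIn
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Passing every argument explicitly avoids the error; relying on any default does not.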

I wrote one possible workaround below, but is there a better way to achieve this? Thanks,

import inspect
import logging
from typing import Any

from pydantic.fields import FieldInfo

from predict import Predictor  # assumes the predictor lives in predict.py, as in a typical cog project

logger = logging.getLogger(__name__)

PREDICTOR_RETURN_TYPE = inspect.signature(Predictor.predict).return_annotation

class TestPredictor(Predictor):
    """A class used only in non-end-to-end tests, i.e. those that call the Predictor
    directly. It is required because, without cog's input processing layer (which runs
    when you use the cog CLI or HTTP API), arguments to `predict()` that are not passed
    explicitly remain `FieldInfo` objects, so any basic operation on them raises an error."""

    def predict(self, **kwargs: Any) -> PREDICTOR_RETURN_TYPE:
        """Processes the inputs (see class docstring), then calls the superclass's predict method.

        Returns:
            PREDICTOR_RETURN_TYPE: The output of the superclass `predict` method.
        """
        # Unwrap any FieldInfo that was passed explicitly
        for kwarg_name, kwarg in kwargs.items():
            kwargs[kwarg_name] = kwarg.default if isinstance(kwarg, FieldInfo) else kwarg
        # Explicitly pass every remaining parameter, using its unwrapped default
        all_predict_params = inspect.signature(Predictor.predict).parameters
        for param_name, param in all_predict_params.items():
            if param_name != "self" and param_name not in kwargs:
                kwargs[param_name] = (
                    param.default.default if isinstance(param.default, FieldInfo) else param.default
                )

        logger.info(f"Predicting with {kwargs}")

        return super().predict(**kwargs)
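
A more generic variant of the same idea, sketched here as an alternative rather than anything cog provides, is a helper that binds the arguments with `inspect.Signature.bind`, fills in omitted parameters with `apply_defaults()`, and unwraps FieldInfo-style values before calling the method. The helper name `call_with_resolved_defaults` and the toy `DemoPredictor` are hypothetical; `FieldInfoStandIn` stands in for pydantic's `FieldInfo` so the sketch runs with the standard library only. In a real test you would check `isinstance(value, FieldInfo)` with `from pydantic.fields import FieldInfo`.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FieldInfoStandIn:
    """Hypothetical stand-in for pydantic's FieldInfo (what cog.Input returns)."""

    default: Any


def call_with_resolved_defaults(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call `func`, unwrapping FieldInfo-style arguments and defaults to their `.default`."""
    # bind() raises TypeError for misspelled names or missing required arguments
    bound = inspect.signature(func).bind(*args, **kwargs)
    # Fill in every parameter the caller omitted with its signature default
    bound.apply_defaults()
    for name, value in list(bound.arguments.items()):
        if isinstance(value, FieldInfoStandIn):
            bound.arguments[name] = value.default
    return func(*bound.args, **bound.kwargs)


class DemoPredictor:
    """Toy predictor used only to exercise the helper."""

    def predict(self, prompt: str, steps: int = FieldInfoStandIn(default=10)) -> str:
        return prompt * min(steps, 3)


predictor = DemoPredictor()
print(call_with_resolved_defaults(predictor.predict, "ab"))  # prints ababab
print(call_with_resolved_defaults(predictor.predict, "ab", steps=1))  # prints ab
```

Compared with subclassing, this works for any callable, rejects misspelled keyword names, and raises a clear `TypeError` when a parameter with no default is omitted, instead of forwarding `inspect.Parameter.empty` the way the subclass above would.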

mattt commented 2 months ago

Hi @Clement-Lelievre. This is a great question; thanks for opening this issue.

You're right that the default Input values complicate unit testing of predict. Cog predictors are intended to be tested end-to-end. (My recommendation would be to use Cog's HTTP interface rather than the CLI, as JSON is more expressive and the CLI's argument serialization isn't guaranteed to be stable across releases.)
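
An end-to-end test over HTTP can stay small. The sketch below uses only the standard library and sends a `POST /predictions` request whose JSON body wraps the inputs under `"input"`, then reads `"status"`, `"output"` and `"error"` from the response. Those field names follow cog's prediction API as we understand it; check them against your cog version. `BASE_URL` is an assumption (a container published on `localhost:5000`), and the helper names are hypothetical.

```python
import json
import urllib.request
from typing import Any

# Assumption: the cog container is published on localhost:5000
# (e.g. `docker run -p 5000:5000 ...`); adjust to match your setup.
BASE_URL = "http://localhost:5000"


def build_prediction_request(inputs: dict[str, Any], base_url: str = BASE_URL) -> urllib.request.Request:
    """Build a POST /predictions request whose JSON body wraps the inputs under "input"."""
    body = json.dumps({"input": inputs}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/predictions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_over_http(inputs: dict[str, Any], base_url: str = BASE_URL, timeout: float = 300.0) -> Any:
    """Send the request to a running container and return the "output" field."""
    with urllib.request.urlopen(build_prediction_request(inputs, base_url), timeout=timeout) as resp:
        payload = json.load(resp)
    if payload.get("status") != "succeeded":
        raise RuntimeError(f"prediction failed: {payload.get('error')}")
    return payload["output"]
```

In a pytest suite that starts the container first, a test would then be a single call such as `predict_over_http({"prompt": "hi"})` followed by assertions on the returned output. Because the inputs are plain JSON, nested or complex types survive that the CLI may struggle to parse.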

One way to mitigate this is to provide explicit arguments for all parameters, so that the default values aren't used. But an easier option is to put the inference logic into a separate method and test that instead.
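
A minimal sketch of that second option: `predict` becomes a thin wrapper that only forwards its (cog-resolved) inputs, while the real work lives in an ordinary method with plain Python defaults that a unit test can call directly. The method name `run_inference` and the toy scaling logic are hypothetical. To keep the sketch runnable without cog, `Input` here is a local stand-in; in a real project you would use `from cog import BasePredictor, Input` and subclass `BasePredictor`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class _InputStandIn:
    """Hypothetical stand-in for what cog.Input returns (a FieldInfo)."""

    default: Any
    description: str = ""


def Input(default: Any = None, description: str = "") -> Any:
    # In a real project: from cog import Input
    return _InputStandIn(default=default, description=description)


class Predictor:  # in a real project: class Predictor(BasePredictor)
    def setup(self) -> None:
        # Stand-in for loading a model; here just a scale factor
        self.scale = 2

    def predict(
        self,
        x: int = Input(default=3, description="Value to transform"),
        steps: int = Input(default=1, description="Number of scaling steps"),
    ) -> int:
        # Thin wrapper: cog resolves the Input defaults before calling this
        return self.run_inference(x=x, steps=steps)

    def run_inference(self, x: int, steps: int = 1) -> int:
        # Plain method with ordinary defaults: safe to unit test directly
        result = x
        for _ in range(steps):
            result *= self.scale
        return result
```

A unit test then calls `setup()` followed by `run_inference(3, steps=2)` without touching `predict` at all. The trade-off is that defaults are stated twice, once in `Input` and once in `run_inference`; either keep them in sync or give `run_inference` no defaults so that tests must be explicit.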

Clement-Lelievre commented 2 months ago

Hi @mattt

Thanks for your reply.

Most of my e2e tests do use the HTTP interface. (Yes, I have noticed that the CLI has a hard time parsing relatively complex types.) Unfortunately, in my use case it's not really an option to pass all parameters explicitly, so I guess I'll stick with the test class above for now.