replicate / cog

Containers for machine learning
https://cog.run
Apache License 2.0

Refactor types and interfaces for "weights" across python/cog/predictor.py #1395

Open vicfryzel opened 9 months ago

vicfryzel commented 9 months ago

predictor.py may benefit from a slight refactoring of how weights are handled. cog.predictor.run_setup() represents weights with a different set of types than the rest of the module. This mismatch exposes an inconsistent interface to classes that extend cog.BasePredictor. It also makes it harder to refactor weight handling more broadly in the future if needed.

Suggestion

Add a WeightsType member to BasePredictor, and simplify the type checking to use it.

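from abc import ABC
from typing import Optional, Union

from .types import File as CogFile, Path as CogPath
# ...

class BasePredictor(ABC):
    # ...
    # Single alias for every type a predictor may receive as weights.
    WeightsType = Optional[Union[CogFile, CogPath]]
    # ...
    def setup(self, weights: WeightsType = None) -> None:
        # ...
        pass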

This turns implementations of BasePredictor from:

class ExampleImageClassificationPredictor(BasePredictor):
    def setup(self, weights: Optional[Union[File, Path]] = None):

into:

class ExampleImageClassificationPredictor(BasePredictor):
    # WeightsType is defined on BasePredictor, so a subclass has to qualify it.
    def setup(self, weights: BasePredictor.WeightsType = None):

This change would need to be pinned to a release that allows API changes, unless the solution can act as a supertype of both the new WeightsType and the current Optional[Union[File, Path]] type.
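For anyone who wants to try the idea outside the Cog codebase, here is a minimal, self-contained sketch. CogFile and CogPath below are hypothetical stand-ins for cog.types.File and cog.types.Path, not the real classes, and ExamplePredictor is only illustrative. One detail the sketch makes visible: an alias defined in the body of BasePredictor is not in scope inside a subclass body, so the subclass refers to it as BasePredictor.WeightsType (a module-level alias would avoid the qualification).

```python
from abc import ABC
from typing import Optional, Union, get_type_hints


class CogFile:
    """Stand-in for cog.types.File (assumption for this sketch)."""

    def __init__(self, name: str) -> None:
        self.name = name


class CogPath:
    """Stand-in for cog.types.Path (assumption for this sketch)."""

    def __init__(self, location: str) -> None:
        self.location = location


class BasePredictor(ABC):
    # One alias describing everything setup() may receive as weights.
    WeightsType = Optional[Union[CogFile, CogPath]]

    def setup(self, weights: WeightsType = None) -> None:
        """Default setup does nothing."""


class ExamplePredictor(BasePredictor):
    # The alias lives on the base class, so it is qualified here.
    def setup(self, weights: BasePredictor.WeightsType = None) -> None:
        self.weights = weights


# The subclass annotation resolves to the same type as the alias.
hints = get_type_hints(ExamplePredictor.setup)
assert hints["weights"] == BasePredictor.WeightsType
```

Because the alias expands to the same Optional[Union[...]] that predictors spell out today, existing signatures and alias-based signatures carry identical annotations in this sketch.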

yorickvP commented 9 months ago

We currently use pydantic to inspect the weights type that your setup() declares, so that we can provide a value of the right type at runtime. That means there's never a need to implement a setup that accepts both of these types.

The example should probably be:

class ExampleImageClassificationPredictor(BasePredictor):
    def setup(self, weights: Optional[Path] = None):
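
To illustrate the idea of choosing the runtime type from the annotation on setup(), here is a rough sketch. It is not Cog's actual implementation: it uses the standard typing helpers instead of pydantic, File and Path are stand-in classes, and declared_weights_type is a hypothetical helper invented for this example.

```python
from typing import Optional, Union, get_args, get_origin, get_type_hints


class File:
    """Stand-in for cog.File (assumption for this sketch)."""


class Path:
    """Stand-in for cog.Path (assumption for this sketch)."""


def declared_weights_type(setup):
    """Return the single concrete type annotated on `weights`, or None.

    Hypothetical helper: unwraps Optional[T] to T, and returns None when
    the parameter is missing, unannotated, or names more than one type.
    """
    annotation = get_type_hints(setup).get("weights")
    if annotation is None:
        return None
    if get_origin(annotation) is Union:
        # Optional[T] is Union[T, None]; drop the None member.
        members = [a for a in get_args(annotation) if a is not type(None)]
        return members[0] if len(members) == 1 else None
    return annotation


class PathPredictor:
    def setup(self, weights: Optional[Path] = None) -> None:
        self.weights = weights


class BothPredictor:
    def setup(self, weights: Optional[Union[File, Path]] = None) -> None:
        self.weights = weights


path_kind = declared_weights_type(PathPredictor.setup)
both_kind = declared_weights_type(BothPredictor.setup)
```

In this sketch a predictor that declares Optional[Path] yields Path, so a caller knows to construct a Path, while the two-type annotation gives the caller nothing to pick from. That is the reason a single declared type is enough.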