superduper-io / superduper

Superduper: build end-2-end AI applications and templates using your existing data infrastructure and tools of choice
https://superduper.io
Apache License 2.0
4.8k stars 464 forks source link

[REMOTE-UTIL] Create model-proxy #2444

Open blythed opened 2 months ago

jieguangzhou commented 2 months ago

All models constructed with serve=True return RemoteModel instances; the wrapped model can be accessed via model.original_model.

The original model is loaded lazily, only when model.original_model is first accessed.

The attribute name original_model is provisional and can be changed.

# Illustrative usage of the proposed model proxy; `db`, `type_id`,
# `identifier` and `uuid` are placeholders, and `...` stands for elided args.

# Scenario 1: predict before applying (no datalayer attached yet)

model = Model(..., serve=True) # returns a RemoteModel

model.predict() # runs the wrapped original model in-process

# Scenario 2: apply first, then predict

model = Model(..., serve=True) # returns a RemoteModel

db.apply(model)

model.predict() # the model is now served; prediction uses the remote path

# Scenario 3: load a served model from the datalayer

model = db.load(type_id, identifier, uuid, ...) # returns a RemoteModel

model.predict() # uses the remote prediction path
class Model(Component, metaclass=ModelMeta):
    """Base model component (most fields and methods elided in this sketch).

    Constructing a model with ``serve=True`` returns a ``RemoteModel`` proxy
    that wraps a fully initialised instance of the requested class.
    """

    # ... other fields and methods elided in this proposal sketch ...

    def __new__(cls, *args, **kwargs):
        """Create a model, or a ``RemoteModel`` proxy when ``serve=True``.

        :param args: Positional constructor arguments for ``cls``.
        :param kwargs: Keyword constructor arguments for ``cls``; the
            ``serve`` flag selects whether a proxy is returned.
        :return: An instance of ``cls`` (initialised by Python as usual), or
            a ``RemoteModel`` wrapping an explicitly initialised instance.
        """
        # NOTE(review): the proposal intends that instances re-created while
        # decoding from the datalayer skip building the original model (it is
        # then loaded lazily through ``RemoteModel.original_model``). How that
        # path is detected is not specified in this sketch -- TODO.
        original_model = super().__new__(cls)
        if not kwargs.get("serve", False):
            # Ordinary construction: Python will call __init__ on this object.
            return original_model

        # Python only invokes __init__ automatically when __new__ returns an
        # instance of ``cls``. We return a RemoteModel instead, so initialise
        # the wrapped model here; otherwise type_id/identifier/uuid below
        # would be missing. (Assumes ModelMeta does not re-invoke __init__ --
        # confirm.)
        original_model.__init__(*args, **kwargs)
        remote_model = RemoteModel(
            type_id=original_model.type_id,
            identifier=original_model.identifier,
            uuid=original_model.uuid,
        )
        remote_model.original_model = original_model
        return remote_model

class RemoteModel(Component):
    """Proxy for a model that is served through the datalayer.

    Prediction calls go in-process to the wrapped original model while no
    datalayer is attached, and to ``db.compute`` once one is.
    """

    type_id: str
    identifier: str
    uuid: str

    def __post_init__(self, db, artifacts):
        # The wrapped model is either attached after construction via the
        # ``original_model`` setter or loaded lazily by its getter.
        self._original_model = None
        return super().__post_init__(db, artifacts)

    def predict(self, *args, **kwargs):
        """Forward ``predict`` to the wrapped model or the compute backend."""
        return self._call_predict_function("predict", *args, **kwargs)

    def predict_batches(self, dataset):
        """Forward ``predict_batches`` to the wrapped model or the compute backend."""
        return self._call_predict_function("predict_batches", dataset)

    def _call_predict_function(self, func_name, *args, **kwargs):
        """Call ``func_name`` on ``db.compute`` if attached, else on the wrapped model.

        :param func_name: Name of the prediction method to invoke.
        :raises RuntimeError: If no datalayer is attached and no original
            model has been set.
        """
        if self.db is not None:
            # Served: the compute backend addresses the model by identifier.
            return getattr(self.db.compute, func_name)(
                self.identifier, *args, **kwargs
            )
        if self._original_model is None:
            raise RuntimeError(
                f"RemoteModel {self.identifier!r} has no datalayer and no "
                "original model to predict with"
            )
        return getattr(self._original_model, func_name)(*args, **kwargs)

    @property
    def original_model(self):
        """The wrapped (non-remote) model, loaded from ``db`` on first access.

        :raises RuntimeError: If no model has been set and no datalayer is
            attached to load it from.
        """
        if self._original_model is None:
            if self.db is None:
                raise RuntimeError(
                    f"RemoteModel {self.identifier!r} has no datalayer to "
                    "load the original model from"
                )
            # NOTE(review): ``uuid`` is not passed, so this may resolve to a
            # different version than ``self.uuid`` -- confirm against db.load.
            self._original_model = self.db.load(
                self.type_id, self.identifier, serve=False
            )
        return self._original_model

    @original_model.setter
    def original_model(self, model):
        self._original_model = model

    def encode(self):
        """Encode the wrapped model if present, otherwise a component reference."""
        if self._original_model is not None:
            return self._original_model.encode()
        return f"&:component:{self.type_id}:{self.identifier}:{self.uuid}"

    def pre_create(self, db: Datalayer) -> None:
        """Apply the wrapped model (if any) before this proxy is created."""
        if self._original_model is not None:
            # Design intent: apply the original model first; the remote
            # proxy is then expected to be skipped by apply.
            db.apply(self._original_model)