Galileo-Galilei / kedro-mlflow

A kedro-plugin for integration of mlflow capabilities inside kedro projects (especially machine learning model versioning and packaging)
https://kedro-mlflow.readthedocs.io/
Apache License 2.0

Model signature and examples #562

Open felipemonroy opened 4 months ago

felipemonroy commented 4 months ago

Description

Add support for model signatures and input examples to the MlflowModelTrackingDataset dataset.

Context

MLflow supports adding a signature and input examples to a model, which surfaces useful information in the model artifact view.

Possible Implementation

At the moment I am passing the signature as a dictionary in the save_args and using an additional omegaconf resolver that transforms the dictionary into a ModelSignature object. Something similar could be achieved inside MlflowModelTrackingDataset.

_model_signature:
    inputs:
        - type: double
          name: trip_distance
          required: true
        - type: double
          name: trip_duration_minutes
          required: true
    outputs:
        - type: double
          required: true
    params: null

model:
    type: kedro_mlflow.io.models.MlflowModelTrackingDataset
    flavor: mlflow.sklearn
    save_args:
        signature: ${create_model_signature:${_model_signature}}

import json

from mlflow.models.signature import ModelSignature
from omegaconf import DictConfig, OmegaConf

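def create_model_signature(model_signature: DictConfig) -> ModelSignature:
    """Build a ModelSignature from the `_model_signature` catalog entry."""
    # Convert the OmegaConf node into plain Python dicts and lists.
    signature_dict = OmegaConf.to_container(model_signature)

    # ModelSignature.from_dict expects each section as a JSON string (or None).
    json_signature = {}
    for key, value in signature_dict.items():
        if value is None:
            json_signature[key] = None
        else:
            json_signature[key] = json.dumps(value)

    return ModelSignature.from_dict(json_signature)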
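
To make the conversion step concrete, here is a minimal standard-library sketch of the intermediate dictionary that the resolver hands to ModelSignature.from_dict. The helper name to_signature_payload is hypothetical; it only mirrors the loop in the resolver, with a plain dict standing in for the OmegaConf container.

```python
import json
from typing import Any, Optional


def to_signature_payload(signature_dict: dict[str, Any]) -> dict[str, Optional[str]]:
    """Serialize each non-null section to a JSON string and keep nulls as None."""
    return {
        key: None if value is None else json.dumps(value)
        for key, value in signature_dict.items()
    }


# Same content as the `_model_signature` entry from the catalog in this issue.
model_signature = {
    "inputs": [
        {"type": "double", "name": "trip_distance", "required": True},
        {"type": "double", "name": "trip_duration_minutes", "required": True},
    ],
    "outputs": [{"type": "double", "required": True}],
    "params": None,
}

payload = to_signature_payload(model_signature)
print(payload["outputs"])
```

The resolver itself still has to be registered under the name create_model_signature before the catalog can reference it, for example with OmegaConf.register_new_resolver or, assuming the project uses Kedro's OmegaConfigLoader, through the custom_resolvers entry of CONFIG_LOADER_ARGS in settings.py.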

Another alternative could be to use mlflow's infer_signature; however, I am not sure how you would pass the object to infer the schema from. The same problem applies to the example, which according to the documentation can be any of pandas.core.frame.DataFrame, numpy.ndarray, dict, list, csr_matrix, csc_matrix, str, bytes, tuple. At the moment I can only pass a dict or list in the catalog yml.
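
Since only a dict or list can be written in the catalog yml, one option might be to derive the signature dictionary from a dict example declared there. The sketch below is not mlflow's infer_signature: spec_from_example and its type mapping are hypothetical, the example values are purely illustrative, and it only handles a flat dict of Python scalars.

```python
from typing import Any

# Assumed mapping from Python scalar types to MLflow schema type names.
# bool comes first because bool is a subclass of int in Python.
_TYPE_NAMES = ((bool, "boolean"), (int, "long"), (float, "double"), (str, "string"))


def spec_from_example(example: dict[str, Any]) -> list[dict[str, Any]]:
    """Build an `inputs`-style column list from one flat example record."""
    columns = []
    for name, value in example.items():
        for python_type, type_name in _TYPE_NAMES:
            if isinstance(value, python_type):
                columns.append({"type": type_name, "name": name, "required": True})
                break
        else:
            # No mapping matched, so refuse to guess a type.
            raise TypeError(
                f"unsupported example value for {name!r}: {type(value).__name__}"
            )
    return columns


# Illustrative values only; the column names come from the catalog in this issue.
input_example = {"trip_distance": 2.5, "trip_duration_minutes": 12.0}
inputs = spec_from_example(input_example)
```

The resulting list has the same shape as the inputs section of _model_signature, so it could be fed through the same resolver path. Anything richer than scalars (DataFrames, arrays, sparse matrices) would still need a real object at save time.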

Galileo-Galilei commented 1 month ago

Hi, sorry for the long delay in replying. I love the resolver solution, and I'm inclined to accept it on the spot. Would you like to raise a PR?