crate / mlflow-cratedb

MLflow adapter for CrateDB.
Apache License 2.0
1 star · 1 fork · source link

"registered_models" requires unique constraint polyfill on "name" #46

Closed: andnig closed this issue 1 year ago.

andnig commented 1 year ago

Issue: We are not able to update models when using CrateDB as the MLflow backend.

With MLflow, `log_model` saves models in the tracking backend, and calling `log_model` twice with the same model name creates a new model version. However, when using CrateDB as the backend, calling `log_model` twice with the same model name results in:

except sqlalchemy.exc.SQLAlchemyError as e:
    raise MlflowException(message=e, error_code=BAD_REQUEST)
    mlflow.exceptions.MlflowException: UPDATE statement on table 'registered_models' expected to update 1 row(s); 2 were matched.

Reason: MLflow expects model names to be unique, as defined in its `SqlRegisteredModel` table definition.

Proposal: Add a uniqueness polyfill on the `name` column to `SqlRegisteredModel`, similar to the existing unique-constraint polyfills in mlflow-cratedb.

Repro steps:

  1. Have a local CrateDB instance running on localhost:4200 with the default credentials.

  2. Create an environment with Python 3.10.

  3. Install the dependencies, start the MLflow server, and point the tracking URI at CrateDB:

    pip install mlflow_cratedb
    pip install pycaret[analysis,models,tuner,parallel,test]
    mlflow-cratedb server --backend-store-uri="${CRATEDB_SQLALCHEMY_URL}" --dev
    export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow"
  4. Run the script below twice. The second run fails.

import os
import time

import numpy as np
import pandas as pd
from crate import client
from mlflow import get_tracking_uri
from mlflow.models import infer_signature
from mlflow.sklearn import log_model
from pycaret.time_series import blend_models, compare_models, finalize_model, save_model, setup, tune_model

import mlflow_cratedb  # noqa: F401

def connect_database():
    """
    Connect to CrateDB, and return database connection object.

    The HTTP endpoint is taken from the ``CRATEDB_HTTP_URL`` environment
    variable and falls back to ``http://crate@localhost:4200``.

    :return: A connection object created by ``crate.client.connect``.
    """
    dburi = os.getenv("CRATEDB_HTTP_URL", "http://crate@localhost:4200")
    return client.connect(dburi)

def table_exists(table_name: str) -> bool:
    """
    Check if database table exists.

    Looks up ``table_name`` in ``information_schema.tables`` restricted to the
    connection's ``CURRENT_SCHEMA``.

    :param table_name: Name of the table to look for.
    :return: ``True`` when the lookup reports at least one matching row.
    """
    conn = connect_database()
    try:
        cursor = conn.cursor()
        # Bind the name as a query parameter ("?" placeholders, as used by the
        # INSERT in this script) instead of splicing it into the SQL text.
        sql = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_name = ? AND table_schema = CURRENT_SCHEMA"
        )
        cursor.execute(sql, (table_name,))
        rowcount = cursor.rowcount
        cursor.close()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()
    return rowcount > 0

def data_available(table_name: str) -> bool:
    """
    Check if data is available in database table.

    :param table_name: Table whose rows are counted. It is interpolated into
        the SQL text as an identifier (identifiers cannot be bound as query
        parameters), so it must come from trusted code, not user input.
    :return: ``True`` when ``SELECT count(*)`` reports at least one row.
    """
    conn = connect_database()
    try:
        cursor = conn.cursor()
        sql = f"SELECT count(*) FROM {table_name}"  # noqa: S608
        cursor.execute(sql)
        rowcount = cursor.fetchone()[0]
        cursor.close()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()
    return rowcount > 0

def import_data(data_table_name: str):
    # Load the benchmark data (a target series and a related price series),
    # merge them on item/org/date, derive total_sales, and bulk-insert the
    # result into the CrateDB table named by ``data_table_name``.
    Download Real-world sales forecasting benchmark data, and load into database.

    # NOTE(review): this copy of the script lost the pd.read_csv(...) arguments
    # (data source and options) and their closing parentheses, so the two calls
    # below are incomplete. Restore them from the upstream example.
    target_data = pd.read_csv(
    related_data = pd.read_csv(
    related_data.columns = ["item", "org", "date", "unit_price"]
    # Attach the unit price to each quantity row (default inner join, so rows
    # without a matching price are dropped), then derive revenue per row.
    data = target_data.merge(related_data, on=["item", "org", "date"])
    data["total_sales"] = data["unit_price"] * data["quantity"]
    data["date"] = pd.to_datetime(data["date"])

    # Split the data into int(len(data) / 1000) roughly equal chunks, i.e. about
    # 1000 rows each (a bit more, since the count is rounded down), for better
    # insert performance.
    # NOTE(review): the section count is 0 for fewer than 1000 rows, which
    # np.array_split rejects with ValueError; max(1, ...) would guard that.
    chunk_size = 1000
    chunks = np.array_split(data, int(len(data) / chunk_size))

    # Insert the data into CrateDB
    with connect_database() as conn:
        cursor = conn.cursor()
        # Create the table if it doesn't exist
        cursor.execute(f"""CREATE TABLE IF NOT EXISTS {data_table_name}
            ("item" TEXT,
            "org" TEXT,
            "date" TIMESTAMP,
            "quantity" BIGINT,
            "unit_price" DOUBLE PRECISION,
            "total_sales" DOUBLE PRECISION)""")

        # Insert the data in chunks
        for chunk in chunks:
            # NOTE(review): the call these arguments belong to (presumably
            # cursor.executemany(, given the "?" placeholders and the list of
            # row tuples) and its closing parenthesis are missing here.
                f"""INSERT INTO {data_table_name}
                (item, org, date, quantity, unit_price, total_sales)
                VALUES (?, ?, ?, ?, ?, ?)""",  # noqa: S608
                list(chunk.itertuples(index=False, name=None)),

def read_data(table_name: str) -> pd.DataFrame:
    """
    Read data from database into pandas DataFrame.

    Aggregates ``total_sales`` per calendar month.

    :param table_name: Table to read from. It is interpolated into the SQL text
        as an identifier, so it must come from trusted code, not user input.
    :return: DataFrame with columns ``month`` (datetime) and ``total_sales``,
        sorted by ``month``.
    """
    query = f"""
            SELECT
                DATE_TRUNC('month', date) as month,
                SUM(total_sales) AS total_sales
            FROM {table_name}
            GROUP BY month
            ORDER BY month
        """  # noqa: S608
    with connect_database() as conn:
        data = pd.read_sql(query, conn)

    # Month values are interpreted as epoch milliseconds and turned into datetimes.
    data["month"] = pd.to_datetime(data["month"], unit="ms")
    data.sort_values(by=["month"], inplace=True)
    return data

def run_experiment(data: pd.DataFrame):
    # Fit, tune and blend PyCaret time-series models on ``data``, save the
    # finalized model under "model/crate-salesforecast", and, when the MLflow
    # tracking URI is not a file:// URI, prepare it for registration with MLflow.
    Run experiment on DataFrame, using PyCaret. Track it using MLflow.
    The mlflow tracking is automatically executed by PyCaret.

    # Pick the 3 models with the best MASE score, tune each one, blend the
    # tuned models (optimizing MASE), and finalize the blend.
    # NOTE(review): the remaining setup(...) arguments and the closing
    # parenthesis were lost in this copy of the script.
    pycaret_setup = setup(data,

    best3 = compare_models(sort="MASE", n_select=3)
    tuned_models = [tune_model(i) for i in best3]
    blended = blend_models(estimator_list=tuned_models, optimize="MASE")
    best_model = finalize_model(blended)

    # saving the model to disk
    # NOTE(review): the body of the `if` below is missing (presumably it created
    # the "model" directory); as written, save_model is not indented under it.
    if not os.path.exists("model"):
    save_model(best_model, 'model/crate-salesforecast')

    # Create a name for the model
    # NOTE(review): timestamp is unused in the visible code; presumably it was
    # part of the registered model name in the lost log_model(...) call.
    timestamp = int(time.time())

    # registering the model with mlflow, but only if MLFLOW_TRACKING_URI is
    # set to a tracking server
    # NOTE(review): the log_model(...) call that should consume `signature`, and
    # the `else:` that should precede the print, are missing. As written, the
    # "not set to a tracking server" message prints in the tracking-server branch.
    if not get_tracking_uri().startswith("file://"):
        y_pred = best_model.predict()
        signature = infer_signature(None, y_pred)
        print(# noqa: T201
            "MLFLOW_TRACKING_URI is not set to a tracking server, "
            "so the model will not be registered with mlflow")

def main():
    """
    Provision dataset, and run experiment.

    On first run, the benchmark data is imported into CrateDB, and the table is
    polled until it reports rows. Afterwards, the monthly aggregate is read and
    the PyCaret experiment is run on it.

    :raises Exception: when a freshly imported table still reports no rows
        after five polls.
    """

    # Table name where the actual data is stored.
    data_table = "sales_data_for_forecast"

    # Provision data to operate on, only once.
    if not table_exists(data_table):
        import_data(data_table)

        # Wait until table is ready: poll a few times, pausing between attempts
        # so the loop does not just fire five queries back to back.
        i = 0
        while not data_available(data_table) and i < 5:
            time.sleep(1)
            i += 1
        if i == 5 and not data_available(data_table):
            raise Exception("Data is not available in database table.")

    # Read data into pandas DataFrame.
    data = read_data(data_table)
    # Run experiment on data.
    run_experiment(data)

if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not on import.
    main()
amotl commented 1 year ago

Dear Andreas,

thanks a stack for reporting this flaw. We think it has been fixed with GH-47 / 72b80689120, which is already part of v2.7.1. Can you confirm that it works for you now?

With kind regards, Andreas.

andnig commented 1 year ago

Happy to confirm that it works with the latest mlflow-cratedb. Kudos for the timely fix. (Two screenshots were attached.)