airbnb / chronon

Chronon is a data platform for serving AI/ML applications.
Apache License 2.0

CHIP-9: Support Model-based Transformations in Join & Chaining #757

Open hzding621 opened 5 months ago

hzding621 commented 5 months ago

CHIP-9: Support Model-based Transformations in Join & Chaining

Problem Statement

Model inference is an important primitive transformation that ML practitioners use when building feature pipelines. The most popular example is embeddings as ML features, where the output of an embedding model (usually a DNN) is used as input features for a downstream model. Once created (trained), the embedding model can be treated as a special case of the general row-to-row transformation function and can be plugged anywhere into a feature pipeline.

This CHIP adds support for model-based transformations in Chronon via an extension to the Join API. Combined with the existing chaining support for joins, this allows Chronon users to build complex feature pipelines that include model-based transformations. As with regular joins, this covers both offline backfills and online serving.

Requirements

Non-Requirements

Join API Changes

We call this the Model Transform API. It extends the current Chronon join with a new model_transform section that handles model inference after the raw feature data is generated. The model transform takes place after the join logic completes and acts as another round of transformation, calling into an external model inference engine (either batch or online) to retrieve the model inference output.

[Diagram: the Model Transform step runs on the join output, after raw feature generation]

Model

Core to the Model Transform API is the Model definition. A Model contains all parameters required to invoke a model backend to run model inference. Note that the model backend implementation will not live in Chronon open source but in an implementation-specific code base. Chronon's responsibility here is to handle the end-to-end orchestration from raw data to final features; it also serves as the central repository for feature metadata.


listing_model = Model(
    inference_spec=InferenceSpec(
        model_backend="<<model-service>",
        model_backend_params={
            "model_name": "search_listing_tower",
            "model_version": "v1",
        }
    ),
    input_schema=[
        ("numeric_features", DataType.LIST(DataType.DOUBLE)),
        ("categorical_features", DataType.LIST(DataType.LONG)),
    ],
    output_schema=[
        ("listing_embeddings", DataType.LIST(DataType.DOUBLE))
    ]
)

query_model = Model(
    inference_spec=InferenceSpec(
        model_backend="<<model-service>",
        model_backend_params={
            "model_name": "search_query_tower",
            "model_version": "v1",
        }
    ),
    input_schema=[
        ("numeric_features", DataType.LIST(DataType.DOUBLE)),
        ("categorical_features", DataType.LIST(DataType.LONG)),
    ],
    output_schema=[
        ("query_embeddings", DataType.LIST(DataType.DOUBLE))
    ]
)

Model Transform

We will introduce ModelTransform as another section in the Join definition. During orchestration, this step runs after derivations and its output becomes the new join output.

ModelTransform contains the core Model definition, as well as some additional join-level parameters for mappings and formatting:

import joins.search.embeddings.utils

listing_tower_join = Join(
    left=utils.listing_driver,
    right_parts=utils.listing_features_v1,
    model_transform=ModelTransform(
        model=listing_model,
        output_mappings={
            "embeddings": "listing_embeddings"
        },
        passthrough_fields=["gb1_feat1", "gb2_feat2"]
    ),
    online=False,
    offline_schedule='@daily'
)

query_tower_join = Join(
    left=utils.query_driver,
    right_parts=utils.query_features,
    model_transform=ModelTransform(
        model=query_model,
        output_mappings={
            "embeddings": "query_embeddings"
        },
    passthrough_fields=["gb1_feat1", "gb2_feat2"]
    ),
    online=True,
    offline_schedule='@never'
)

Model Backend APIs

The model backend will need to implement the following APIs, which Chronon will invoke during orchestration.

def registerModel(model: Model): Future[RegistrationResponse]
// Send any required metadata to the model backend and prepare it for model inference. 

def registerModelTransform(join: Join): Future[RegistrationResponse]
// Send any required metadata to the model backend and prepare it for (batch) model inference. 

def runModelBatchJob(join: Join, start_ds: String, end_ds: String): Future[JobId]
// Run a batch model inference job for a given join.

def getModelBatchJobStatus(jobId: JobId, start_ds: String, end_ds: String): Future[JobStatus]
// Get the status of a batch model inference job.

def runModelInference(join: Join, inputs: Seq[Map[String, AnyRef]]): Future[Seq[Map[String, AnyRef]]]
// Run online model inference which returns the model output directly.
// Will be used in fetcher.fetchJoin
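
For concreteness, below is a minimal, purely illustrative Python sketch of the batched online-inference contract above. The class name, the fabricated embedding logic, and EMBEDDING_DIM are all hypothetical; a real backend lives in an implementation-specific code base and would call an actual model service.

from typing import Any

EMBEDDING_DIM = 4  # hypothetical model output size

class EchoModelBackend:
    """Toy backend illustrating the batched inference contract:
    inputs and outputs are parallel lists of column-name -> value maps."""

    def run_model_inference(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        outputs = []
        for row in inputs:
            # A real backend would call the model service here; we fabricate a
            # fixed-size embedding so the output matches the model's output_schema.
            numeric = row.get("numeric_features") or []
            embedding = [float(sum(numeric))] * EMBEDDING_DIM
            outputs.append({"query_embeddings": embedding})
        # Contract: one output row per input row, in the same order.
        assert len(outputs) == len(inputs)
        return outputs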

Orchestration Topology

Metadata Registration and Validation

[Diagram: metadata registration and validation flow]

Join Backfill

[Diagram: join backfill flow]

Join Fetching

[Diagram: join fetching flow]

Orchestration Details

Join Level Operations

Analyzer

Join Backfill

Fetcher

Model Metadata Upload

Group By Level Operations (for Chaining)

The operations below relate to Chaining, where the output of a Join with a Model Transform is used as a JoinSource in a downstream GroupBy, which can be either a batch GroupBy or a streaming GroupBy. A sketch of such a chained GroupBy follows.
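
For illustration, here is a sketch of a downstream GroupBy chained off the query_tower_join above. It assumes the existing JoinSource construct and Source(joinSource=...) wiring from chaining support; the module path, column names, aggregation, and window are hypothetical.

from ai.chronon.api.ttypes import Source, JoinSource, Window, TimeUnit
from ai.chronon.group_by import GroupBy, Aggregation, Operation
from ai.chronon.query import Query, select

import joins.search.embeddings.towers as towers  # hypothetical module defining query_tower_join

# The upstream join's output (including the model_transform column
# query_embeddings) becomes the input of this downstream GroupBy.
chained_source = Source(
    joinSource=JoinSource(
        join=towers.query_tower_join,
        query=Query(selects=select("user_id", "query_embeddings")),
    )
)

user_latest_query_embedding = GroupBy(
    sources=[chained_source],
    keys=["user_id"],
    aggregations=[
        Aggregation(
            input_column="query_embeddings",
            operation=Operation.LAST,
            windows=[Window(length=7, timeUnit=TimeUnit.DAYS)],
        )
    ],
    online=True,  # served via a streaming GroupBy, per the chaining support above
)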

Group By Upload

Group By Streaming

nikhilsimha commented 5 months ago
def runModelInference(join: Join, inputs: Map[String, AnyRef]): Future[Map[String, AnyRef]]

This should instead be a batch / multi method:

def runModelInference(join: Join, inputs: Seq[Map[String, AnyRef]]): Future[Seq[Map[String, AnyRef]]]