apache / airflow

Apache Airflow - A platform to programmatically author, schedule, and monitor workflows
https://airflow.apache.org/
Apache License 2.0

ONNX Model Inference Operator #41702

Open Faakhir30 opened 2 weeks ago

Faakhir30 commented 2 weeks ago

Description

ONNX (Open Neural Network Exchange) provides cross-platform compatibility

An operator that can run inference using ONNX models would give us direct model invocation from a task, and is ideal for deploying machine learning models in a standardized format.

This can of course be solved with a PythonOperator, since onnxruntime can be called from any Python callable, but it could also be built into Airflow to minimize boilerplate. A simple ONNX operator structure would be something like:


import numpy as np
import onnxruntime as ort
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime

def run_onnx_inference():
    # Load the ONNX model
    model_path = '/path/to/your/model.onnx'
    session = ort.InferenceSession(model_path)

    # Prepare input data: the array must match the model's expected shape and dtype
    input_name = session.get_inputs()[0].name
    input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)

    # Run inference, feeding the array to the model's first input
    result = session.run(None, {input_name: input_data})
    print(result)

# Define the DAG
with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once'
) as dag:

    # Define the task
    inference_task = PythonOperator(
        task_id='onnx_inference_task',
        python_callable=run_onnx_inference
    )
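
Note that the feed keys and array shapes have to match whatever the model was exported with. As a quick sanity check (the model path is just a placeholder), the expected input names, shapes, and types can be inspected with onnxruntime before wiring up the task:

import onnxruntime as ort

session = ort.InferenceSession('/path/to/your/model.onnx')
for inp in session.get_inputs():
    # e.g. name='input', shape=[None, 3], type='tensor(float)'
    print(inp.name, inp.shape, inp.type)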

Looking forward to any suggestions.

Use case/motivation

Direct support for ONNX within Airflow's DAG-based orchestration would allow the entire lifecycle of data processing and model inference to be managed in one place, providing a more cohesive and manageable workflow.

Related issues

No response

Are you willing to submit a PR?

Code of Conduct

boring-cyborg[bot] commented 2 weeks ago

Thanks for opening your first issue here! Be sure to follow the issue template! If you are willing to raise PR to address this issue please do so, no need to wait for approval.

Rohanberiwal commented 2 weeks ago

Hi, I have worked on this issue for the past two days and came up with a solution. I made certain changes to the existing code and added an execute method inside the operator class that does the same work as your run_onnx_inference(). Please see this code and tell me whether it matches your expectations.


import numpy as np
import onnxruntime as ort
from airflow.models import BaseOperator
from airflow import DAG
from datetime import datetime

class ONNXInferenceOperator(BaseOperator):
    def __init__(self, model_path: str, input_data: dict, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.input_data = input_data

    def execute(self, context):
        session = ort.InferenceSession(self.model_path)
        # input_data maps model input names to arrays; values are converted to float32 arrays for onnxruntime
        feed = {name: np.asarray(value, dtype=np.float32) for name, value in self.input_data.items()}
        result = session.run(None, feed)
        self.log.info("Inference result: %s", result)
        return result

with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',
    catchup=False
) as dag:

    inference_task = ONNXInferenceOperator(
        task_id='onnx_inference_task',
        model_path='/path/to/your/model.onnx',
        input_data={"your_input_key": [[1.0, 2.0, 3.0]]}
    )

    inference_task
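
If it helps while iterating on this, one way to smoke-test the DAG locally without a scheduler (assuming Airflow 2.5+, where DAG.test() is available) is roughly:

if __name__ == "__main__":
    # Run the DAG once in-process, without a scheduler or executor
    dag.test()
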
Rohanberiwal commented 2 weeks ago

Hi, I was expecting a reply from you. Whenever you see this, do let me know. Thank you.

eladkal commented 2 weeks ago

Hi @Rohanberiwal @Faakhir30. Airflow doesn't have an ONNX provider, so if you'd like to add it to Airflow you need to follow the protocol for adding a new provider. Most providers are managed by the community rather than by Airflow.
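
For context, community providers are released as separate packages, so if an ONNX provider were accepted, users would install it and import the operator from the provider namespace. A rough sketch of what that would look like (the package name and import path below are hypothetical placeholders, not an existing provider):

# Hypothetical usage if an ONNX provider were accepted as a community provider
# (package name and import path are placeholders, not an existing provider):
#   pip install apache-airflow-providers-onnx
from airflow.providers.onnx.operators.onnx import ONNXInferenceOperator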

Rohanberiwal commented 2 weeks ago

Yes sir, I will read that protocol and get back with a solution as soon as possible. Thank you for your reply.

Rohanberiwal commented 1 week ago

ONNX Inference Operator for Apache Airflow

Description

The ONNXInferenceOperator is a custom operator designed for running inference using ONNX models within an Apache Airflow DAG. This operator leverages the onnxruntime library to load an ONNX model and perform inference on provided input data. The results of the inference are logged and returned.

Components

  1. ONNXInferenceOperator: A custom Airflow operator that initializes with the path to the ONNX model and the input data. It performs inference in the execute method and logs the results.

  2. run_onnx_inference: A helper function that demonstrates how to run inference using the onnxruntime library directly within a PythonOperator. This function is provided as an alternative approach to using the custom operator.

  3. DAG Definition: Defines an Airflow DAG named onnx_inference_dag that schedules the inference task to run once.

Code


import numpy as np
import onnxruntime as ort
from airflow import DAG
from airflow.models import BaseOperator
from airflow.operators.python import PythonOperator
from datetime import datetime

class ONNXInferenceOperator(BaseOperator):
    def __init__(self, model_path: str, input_data: dict, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.input_data = input_data

    def execute(self, context):
        session = ort.InferenceSession(self.model_path)
        # input_data maps model input names to arrays; values are converted to float32 arrays for onnxruntime
        feed = {name: np.asarray(value, dtype=np.float32) for name, value in self.input_data.items()}
        result = session.run(None, feed)
        self.log.info("Inference result: %s", result)
        return result

def run_onnx_inference():
    model_path = '/path/to/your/model.onnx'
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    # Input must be an array matching the model's expected shape and dtype
    input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
    result = session.run(None, {input_name: input_data})
    print(result)

with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',
    catchup=False
) as dag:

    inference_task = ONNXInferenceOperator(
        task_id='onnx_inference_task',
        model_path='/path/to/your/model.onnx',
        input_data={"your_input_key": [[1.0, 2.0, 3.0]]}
    )
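
One thing worth flagging about the execute() return value: Airflow pushes it to XCom, and depending on the Airflow version and XCom backend, numpy arrays may not serialize cleanly. A minimal sketch of an XCom-friendly variant of the method above (assuming all model outputs are plain tensors):

    def execute(self, context):
        session = ort.InferenceSession(self.model_path)
        feed = {name: np.asarray(value, dtype=np.float32) for name, value in self.input_data.items()}
        result = session.run(None, feed)
        self.log.info("Inference result: %s", result)
        # Convert numpy outputs to plain lists so the default JSON-based XCom backend can store them
        return [output.tolist() for output in result]
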
Rohanberiwal commented 1 week ago

Sir, I would like to know more about the official process of getting accepted to work on Airflow. I have made the solution and read the protocol, but where should I raise a vote so I can have a conversation with the people and they accept me?
Should I add the label of new use to the above proposed solution?

potiuk commented 1 week ago

It's all explained there - including (recently added) links to examples where others attempted to propose their providers: https://github.com/apache/airflow/blob/main/PROVIDERS.rst#accepting-new-community-providers

Note - that it's rather unll