LamaAni / KubernetesJobOperator

An Airflow operator that executes a task in a Kubernetes cluster, given a Kubernetes YAML configuration or an image reference.

FEATURE: TaskFlow / Decorated Task #96

Open pykenny opened 3 months ago

pykenny commented 3 months ago

Feature description and background

Currently, the Kubernetes operator supports generating XCom output from the job's output. For receiving XCom input, however, it only provides the most basic approach: argument templating, with TaskInstance.xcom_pull() called inside the template.

Fighting with Jinja templating can be cumbersome, especially when accessing XCom output from multiple upstream tasks. For instance, to run a Kubernetes job whose dynamic arguments (the arguments parameter) depend on several upstream tasks, one needs to build a Jinja template that renders the argument list, and remember to enable the render_template_as_native_obj flag on the DAG it runs in.
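For illustration, with the current approach the job's arguments would be built roughly like the sketch below (task ids and image name are made up; the full example further down expands on this):

# Sketch of today's templated approach: the argument list is rendered as a
# string and parsed back into a native list, which requires
# render_template_as_native_obj=True on the DAG. Task ids are illustrative.
from airflow_kubernetes_job_operator.kubernetes_job_operator import KubernetesJobOperator

job = KubernetesJobOperator(
    task_id="my_job",
    image="myimage",
    arguments=(
        "["
        "\"--flag\", \"{{ ti.xcom_pull('upstream_a') }}\", "
        "\"{{ ti.xcom_pull('upstream_b') }}\""
        "]"
    ),
)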

Given that Airflow introduced the TaskFlow concept and has been promoting decorated tasks since the 2.0 release, adopting this paradigm and adding a decorated form of KubernetesJobOperator (e.g. @task.kubernetes_job) could be helpful: use cases like the one described above could be handled more smoothly by passing upstream task outputs to the decorated task function directly.

Proposed Solution

For an operator that only runs Kubernetes jobs, it may be difficult to mirror the implementation of @task.kubernetes (a wrapper around Airflow's KubernetesPodOperator) and move the task's program logic into the decorated function. However, I think it's feasible to have the decorated function return arguments for the KubernetesJobOperator instead. The dynamically generated arguments can then be integrated/merged with the KubernetesJobOperator's defaults and the arguments passed in the decorator header.

For example, suppose we define a decorated task function like this:

@task.kubernetes_job(task_id="my_task_id", image="myimage")
def my_job(options: list[str] | None = None):
    positional_args = ["a", "b", "c"]
    return {"arguments"=(options + positional_args) if options else positional_args}

When the list ["-n", "-f", "source_data.csv"] is passed to my_job to create a task instance within a DAG:

my_task = my_job(options=["-n", "-f", "source_data.csv"])

then the created task is equivalent to:

KubernetesJobOperator(
   task_id="my_task_id",
   image="myimage",
   arguments=["-n", "-f", "source_data.csv", "a", "b", "c"]
)

When a task's output is passed to the options argument, the value of arguments will depend on the value of that output.

# Task created from task operator class
upstream_task = PythonOperator(python_callable=upstream_task_f, task_id="upstream")
downstream_task = my_job(options=upstream_task.output)

# Task created from decorated task (TaskFlow)
upstream_task = task(upstream_task_f, task_id="upstream")()
downstream_task = my_job(options=upstream_task)
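In other words (a minimal sketch of the idea, not a proposed API), the decorator would merge the static kwargs from its header with the dict returned by the decorated function before constructing the operator:

# Illustrative only: how the decorator could combine the header kwargs with the
# dict returned by the decorated function. The helper name is hypothetical.
def build_operator_kwargs(header_kwargs: dict, returned_kwargs: dict) -> dict:
    # Values returned by the decorated function extend/override the static
    # values given in the decorator header.
    return {**header_kwargs, **returned_kwargs}

# build_operator_kwargs(
#     {"task_id": "my_task_id", "image": "myimage"},
#     {"arguments": ["-n", "-f", "source_data.csv", "a", "b", "c"]},
# )
# == {"task_id": "my_task_id", "image": "myimage",
#     "arguments": ["-n", "-f", "source_data.csv", "a", "b", "c"]}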

Example DAG implementation before introducing feature

Say we have a container image containing a program that transfers data from several sources to a destination at once, where the image's entry point accepts several optional arguments to customize the transfer (--sources, --enable-high-loading-mode/--disable-high-loading-mode, and --bucket in the DAGs below).

Now we want to design a DAG that creates a Kubernetes job and runs this image when triggered, and we decide to collect the argument information for the image by setting up separate upstream tasks.

To pass information from these upstream tasks into the final KubernetesJobOperator task directly, one has to create a Jinja template that generates the arguments parameter, pulling the upstream outputs with TaskInstance.xcom_pull() calls inside the template:

from airflow.decorators import dag, task
from airflow_kubernetes_job_operator.kubernetes_job_operator import KubernetesJobOperator
import pendulum

@dag(
    dag_id="transfer_pipeline",
    schedule=None,
    start_date=pendulum.datetime(2000, 1, 1, tz="UTC"),
    catchup=False,
    render_template_as_native_obj=True,
)
def example_dag():
    # Upstream task that evaluates sources with new incoming data
    @task(task_id="get_sources")
    def get_sources_f() -> list[str]:
        from my_package.tasks.example_dag import GetSources

        incoming_data_sources = GetSources().run()

        return incoming_data_sources

    # Upstream task that evaluates amount of data to be pulled
    @task(task_id="get_incoming_data_amount"):
    def get_incoming_data_amount_f() -> int:
        from my_package.tasks.example_dag import GetIncomingDataSummary

        data_amount = GetIncomingDataSummary().run()

        return data_amount

    @task(task_id="evaluate_high_loading")
    def evaluate_high_loading_f(amount: int) -> bool:
        threshold = 1_000_000

        return amount > threshold

    # Creates a temporary storage (e.g. S3 bucket) to store intermediate data dumps
    @task(task_id="get_bucket_name") -> str:
    def get_bucket_name_f()
        from my_package.tasks.example_dag import GetBucketName

        bucket_name = GetBucketName().run()

        return bucket_name

    incoming_data_sources = get_sources_f()
    incoming_data_amount = get_incoming_data_amount_f()
    high_loading = evaluate_high_loading_f(incoming_data_amount)
    bucket_name = get_bucket_name_f()

    transfer_job = KubernetesJobOperator(
        task_id="transfer",
        image="myimage",
        jinja_job_arg=True,
        arguments=(
            "["
            "\"--sources\", {{ ti.xcom_pull('get_sources') }}, "
            "\"{{ '--enable-high-loading-mode' if ti.xcom_pull('evaluate_high_loading') else '--disable-high-loading-mode' }}\", "
            "\"--bucket\", \"{{ ti.xcom_pull('get_bucket_name') }}\""
            "]"
        ),
    )

dag = example_dag()

As an alternative, one can insert an additional task in between that organizes the upstream outputs into a single argument list, to avoid templating:

from airflow.decorators import dag, task
from airflow_kubernetes_job_operator.kubernetes_job_operator import KubernetesJobOperator
import pendulum

@dag(
    dag_id="transfer_pipeline",
    schedule=None,
    start_date=pendulum.datetime(2000, 1, 1, tz="UTC"),
    catchup=False,
    render_template_as_native_obj=True,
)
def example_dag():
    # Upstream task that evaluates sources with new incoming data
    @task(task_id="get_sources")
    def get_sources_f() -> list[str]:
        from my_package.tasks.example_dag import GetSources

        incoming_data_sources = GetSources().run()

        return incoming_data_sources

    # Upstream task that evaluates amount of data to be pulled
    @task(task_id="get_incoming_data_amount"):
    def get_incoming_data_amount_f() -> int:
        from my_package.tasks.example_dag import GetIncomingDataSummary

        data_amount = GetIncomingDataSummary().run()

        return data_amount

    @task(task_id="evaluate_high_loading")
    def evaluate_high_loading_f(amount: int) -> bool:
        threshold = 1_000_000

        return amount > threshold

    # Creates a temporary storage (e.g. S3 bucket) to store intermediate data dumps
    @task(task_id="get_bucket_name") -> str:
    def get_bucket_name_f()
        from my_package.tasks.example_dag import GetBucketName

        bucket_name = GetBucketName().run()

        return bucket_name

    # Organizes upstream information and generates the final argument list
    @task(task_id="organize_arguments")
    def organize_arguments_f(data_sources: list[str], high_loading: bool, bucket_name: str):
        return [
            "--sources",
            *data_sources,
            (
                "--enable-high-loading-mode"
                if high_loading
                else "--disable-high-loading-mode"
            ),
            "--bucket",
            bucket_name
        ]

    incoming_data_sources = get_sources_f()
    incoming_data_amount = get_incoming_data_amount_f()
    high_loading = evaluate_high_loading_f(incoming_data_amount)
    bucket_name = get_bucket_name_f()

    organized_arguments = organize_arguments_f(
        data_sources=incoming_data_sources,
        high_loading=high_loading,
        bucket_name=bucket_name,
    )

    transfer_job = KubernetesJobOperator(
        task_id="transfer",
        image="myimage",
        jinja_job_arg=True,
        arguments="{{ ti.xcom_pull('organize_arguments') }}"
        # If TaskFlow is supported we can pass the task instance instead:
        # arguments=organized_arguments
    )

dag = example_dag()

Example DAG implementation after introducing feature

The sample code below replicates the example above using the proposed approach. The task ID and image name are assumed to be "fixed" and passed in the decorator header, while the arguments field is dynamically generated within the decorated function:

from airflow.decorators import dag, task
import pendulum

@dag(
    dag_id="transfer_pipeline",
    schedule=None,
    start_date=pendulum.datetime(2000, 1, 1, tz="UTC"),
    catchup=False,
)
def example_dag():
    # Upstream task that evaluates sources with new incoming data
    @task(task_id="get_sources")
    def get_sources_f() -> list[str]:
        from my_package.tasks.example_dag import GetSources

        incoming_data_sources = GetSources().run()

        return incoming_data_sources

    # Upstream task that evaluates amount of data to be pulled
    @task(task_id="get_incoming_data_amount"):
    def get_incoming_data_amount_f() -> int:
        from my_package.tasks.example_dag import GetIncomingDataSummary

        data_amount = GetIncomingDataSummary().run()

        return data_amount

    @task(task_id="evaluate_high_loading")
    def evaluate_high_loading_f(amount: int) -> bool:
        threshold = 1_000_000

        return amount > threshold

    # Creates a temporary storage (e.g. S3 bucket) to store intermediate data dumps
    @task(task_id="get_bucket_name") -> str:
    def get_bucket_name_f()
        from my_package.tasks.example_dag import GetBucketName

        bucket_name = GetBucketName().run()

        return bucket_name

    @task.kubernetes_job(task_id="transfer", image="myimage")
    def transfer_f(data_sources: list[str], high_loading: bool, bucket_name: str):
        return {
            "arguments": [
                "--sources",
                *data_sources,
                ("--enable-high-loading-mode" if high_loading else "--disable-high-loading-mode"),
                "--bucket",
                bucket_name,
            ]
        }

    incoming_data_sources = get_sources_f()
    incoming_data_amount = get_incoming_data_amount_f()
    high_loading = evaluate_high_loading_f(incoming_data_amount)
    bucket_name = get_bucket_name_f()
    transfer_job = transfer_f(
        data_sources=incoming_data_sources,
        high_loading=high_loading,
        bucket_name=bucket_name
    )

dag = example_dag()

LamaAni commented 3 months ago

Hi,

I do love the idea of decorators, but I'm having a hard time understanding the use case. Can you provide a simple example, e.g.:

  1. What is the task we are trying to complete?
  2. This is how it would be implemented today. (Full dag with description)
  3. This is how to implement it with Task flow. (Full dag with description)

Please make sure the example is as short and simple as possible, just to illustrate the point, so I understand the problem better. You can use TaskFlow and normal Airflow operators to describe it instead of the KubernetesJobOperator if that would be easier.

Let me also look into Airflow's TaskFlow, since I have not used it, and see if we can produce something similar to its internal decorators. We may be able to reproduce the operations it provides if we look into the source code; in general, KubernetesJobOperator uses the Airflow operator as its base class.
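For what it's worth, recent Airflow versions let a package expose custom task decorators built on airflow.decorators.base. A very rough sketch of what wiring up a @task.kubernetes_job could look like under that mechanism is below; the class and function names are hypothetical and the merge logic only illustrates the proposal, it is not a working implementation:

# Rough sketch only, assuming Airflow's custom task decorator machinery
# (airflow.decorators.base). Names below are hypothetical.
from typing import Callable, Optional

from airflow.decorators.base import DecoratedOperator, task_decorator_factory

from airflow_kubernetes_job_operator.kubernetes_job_operator import KubernetesJobOperator


class _KubernetesJobDecoratedOperator(DecoratedOperator, KubernetesJobOperator):
    def execute(self, context):
        # Run the decorated callable with its (XCom-resolved) arguments and
        # merge the returned dict into the operator's own fields.
        dynamic_kwargs = self.python_callable(*self.op_args, **self.op_kwargs) or {}
        for name, value in dynamic_kwargs.items():
            setattr(self, name, value)
        # Then run the Kubernetes job itself, skipping DecoratedOperator.execute.
        return KubernetesJobOperator.execute(self, context)


def kubernetes_job_task(python_callable: Optional[Callable] = None, **kwargs):
    # Hook the decorated operator into Airflow's decorator factory so it can be
    # exposed as @task.kubernetes_job.
    return task_decorator_factory(
        python_callable=python_callable,
        decorated_operator_class=_KubernetesJobDecoratedOperator,
        **kwargs,
    )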

It may take a while since I am on vacation right now (I would accept PRs :))
