apache / airflow

Apache Airflow - A platform to programmatically author, schedule, and monitor workflows
https://airflow.apache.org/
Apache License 2.0

MappedTasks don't short-circuit #43883

Open Chais opened 1 week ago

Chais commented 1 week ago

Apache Airflow version

2.10.3

If "Other Airflow 2 version" selected, which one?

No response

What happened?

It seems that tasks created via .expand() do not inherit the downstream tasks of their parent/original task. This becomes an issue when a task.short_circuit task is expanded, because short-circuiting then no longer works.

What you think should happen instead?

My understanding from the documentation is that step_two.expand(do=step_one.expand(i=start())) is supposed to be equivalent to:

flowchart LR
start --> A1[step_one_0] --> B1[step_two_0]
start --> A2[step_one_1] --> B2[step_two_1]
start --> A3[step_one_2] --> B3[step_two_2]
start --> A4[…] --> B4[…]
start --> A9[step_one_9] --> B9[step_two_9]
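
The expectation in the diagram can also be modelled in plain Python, without Airflow. This is only a sketch of the per-index semantics described above, not of how Airflow actually schedules mapped tasks; `run_pipeline` and its `decide` parameter are hypothetical names introduced for the illustration, and a deterministic condition stands in for the random one used in the reproduction.

```python
from typing import Callable, Dict, List


def run_pipeline(items: List[int], decide: Callable[[int], bool]) -> Dict[int, str]:
    """Model the expected semantics: step_two_i runs only if step_one_i returned True.

    Hypothetical helper for illustration; this is not Airflow code.
    """
    states: Dict[int, str] = {}
    for map_index, item in enumerate(items):
        # step_one_i: evaluate the short-circuit condition for this mapped instance.
        if decide(item):
            states[map_index] = "success"  # step_two_i runs
        else:
            states[map_index] = "skipped"  # step_two_i is short-circuited
    return states


# Deterministic stand-in for random.random() >= 0.5: even inputs pass.
states = run_pipeline(list(range(10)), decide=lambda i: i % 2 == 0)
print(states)
```

In this model, each step_two_i is decided by its own step_one_i alone, which is the one-to-one pairing the flowchart shows.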

How to reproduce

import random
from typing import List

import pendulum
from airflow.decorators import dag, task

@dag(
    "test_foo",
    schedule=None,
    start_date=pendulum.now(),
    render_template_as_native_obj=True,
    dag_display_name="Test random things",
)
def test_foo():
    @task.python
    def start() -> List[int]:
        return [i for i in range(10)]

    @task.short_circuit
    def step_one(i: int) -> bool:
        print(f"Hello from step {i}")
        return random.random() >= 0.5

    @task.python
    def step_two(do: bool):
        if not do:
            print("Should've been skipped.")
        print("Doing stuff")

    step_two.expand(do=step_one.expand(i=start()))

test_foo()

if __name__ == "__main__":
    test_foo().test()

On average, about 5 of the 10 mapped step_one instances will return False and should short-circuit. When one does, the log reads something like this:

Hello from step 5
[2024-11-11 13:31:09,695] {python.py:240} INFO - Done. Returned value was: False
[2024-11-11 13:31:09,697] {python.py:309} INFO - Condition result is False
[2024-11-11 13:31:09,698] {python.py:316} INFO - No downstream tasks; nothing to do.
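
The estimate of roughly 5 short-circuits per run follows from the condition `random.random() >= 0.5`, which is False about half the time. A quick calculation, assuming the 10 mapped instances are independent, shows why practically every run of the reproduction should hit the problem:

```python
from math import comb

N = 10         # number of mapped step_one instances in the reproduction
P_FALSE = 0.5  # probability that random.random() >= 0.5 evaluates to False

# Expected number of instances that return False (binomial mean).
expected_skips = N * P_FALSE

# Probability that at least one instance returns False in a run.
p_at_least_one_skip = 1 - (1 - P_FALSE) ** N

# Probability that exactly five instances return False.
p_exactly_five = comb(N, 5) * P_FALSE**5 * (1 - P_FALSE) ** 5

print(expected_skips)
print(round(p_at_least_one_skip, 4))
print(round(p_exactly_five, 4))
```

So while exactly five False results is not the most likely outcome by a wide margin, a run with no False result at all is very rare.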

However, checking the task_dict variable in the module shows that step_one does have step_two set as a downstream task and, vice versa, step_two has step_one set as an upstream task.

__pydevd_ret_val_dict['factory'].task_dict['step_one'].downstream_task_ids
{'step_two'}
__pydevd_ret_val_dict['factory'].task_dict['step_two'].upstream_task_ids
{'step_one'}

Operating System

Ubuntu 24.04.1 LTS (Noble Numbat)

Versions of Apache Airflow Providers

No response

Deployment

Official Apache Airflow Helm Chart

Deployment details

I'm testing this on Standalone

Anything else?

I can get closer to the desired behaviour by using a list comprehension in the dependency, but that can't be done dynamically because I can't iterate over PlainXComArgs. It also doesn't produce exactly the desired behaviour.

Chais commented 1 week ago

I was able to get the desired behaviour by creating a task group, pulling the return_value XCom from start and manually creating all the tasks in a loop, which I feel shouldn't be necessary:

from random import random
from typing import List

from airflow.decorators import dag, task, task_group
from airflow.models import XCom
from pendulum import now

@dag(
    "test_foo",
    schedule=None,
    start_date=now(),
    dag_display_name="Test random things",
)
def test_foo():
    @task.python
    def start() -> List[int]:
        return [i for i in range(1, 11)]

    @task_group()
    def tg():
        @task.short_circuit
        def step_one(i: int) -> bool:
            print(f"Hello from step {i}")
            return random() >= 0.5

        @task.python(trigger_rule="one_done")
        def step_two(do: bool):
            if not do:
                print("Should've been skipped.")
            print("Doing stuff")

        for i in XCom.get_one(key="return_value", task_id="start", run_id=XCom.run_id):
            step_two(step_one(i))

    start() >> tg()

test_foo()

if __name__ == "__main__":
    test_foo().test()

Chais commented 3 days ago

I'm seeing this on 2.9.3, too.