flyteorg / flyte

Scalable and flexible workflow orchestration platform that seamlessly unifies data, ML and analytics stacks.
https://flyte.org
Apache License 2.0

[BUG] ArrayNode downloads all inputs for every subtask #5483

Closed by EngHabu 1 week ago

EngHabu commented 1 week ago

Describe the bug

Simple repro:

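import typing

import numpy as np
from flytekit import map_task, task, workflow

@task
def featurization(input: np.ndarray):
    ...

@workflow
def wf_w_np(i: typing.List[np.ndarray]):
    # Each subtask should only need its own element of i.
    map_task(featurization)(input=i)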

Looking at the logs of each featurization task, you will see it downloads all the inputs for the map task.

Expected behavior

Each subtask should download only its input.

Additional context to reproduce

The issue is that this line, which converts the literal input map into its Python equivalent, runs against the map_task itself, whose input type is an array. It therefore attempts to convert the entire list of inputs into a list of NumPy arrays. Only later, at this line, does it call array_node.execute, which accesses just a single input.
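The difference between the two orderings can be illustrated with a small standalone sketch. This is not flytekit code: `make_counting_loader`, `eager_subtask_input`, and `lazy_subtask_input` are hypothetical names, and the "download" is simulated by recording which references were touched. It assumes only that converting a literal to its Python value is what triggers the download.

```python
from typing import Callable, List


def make_counting_loader(log: List[str]) -> Callable[[str], str]:
    """Return a fake loader that records every reference it 'downloads'."""

    def load(ref: str) -> str:
        log.append(ref)  # stands in for an expensive remote download
        return f"data:{ref}"

    return load


def eager_subtask_input(refs: List[str], index: int, load: Callable[[str], str]) -> str:
    """Reported behavior: convert the whole list, then pick one element."""
    converted = [load(ref) for ref in refs]
    return converted[index]


def lazy_subtask_input(refs: List[str], index: int, load: Callable[[str], str]) -> str:
    """Expected behavior: select the element first, then convert only that one."""
    return load(refs[index])
```

With three references and subtask index 1, the eager version touches all three references while the lazy version touches only the second, even though both return the same value.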

Screenshots

No response

runllm[bot] commented 1 week ago

Based on Flyte's RFC for ArrayNode, the proposed ArrayNode implementation aims to address issues like the one you described by integrating map tasks directly into Flyte's DAG execution logic. This should ensure that each subtask only accesses its specific input, rather than downloading all inputs.

Would you like more details on this topic?

This is an AI-generated response and your feedback is appreciated! Please leave a 👍 if this is helpful and 👎 if it is not.

Mention @runllm to ask follow-up questions. See runllm.com for more!

EngHabu commented 1 week ago

cc @pingsutw @wild-endeavor