DAGWorks-Inc / hamilton

Hamilton helps data scientists and engineers define testable, modular, self-documenting dataflows, that encode lineage/tracing and metadata. Runs and scales everywhere python does.
https://hamilton.dagworks.io/en/latest/
BSD 3-Clause Clear License
1.8k stars 117 forks

`@pipe_output`, `@pipe_input` and `@mutate` for async functions #1193

Open elijahbenizzy opened 4 hours ago

elijahbenizzy commented 4 hours ago

Is your feature request related to a problem? Please describe.

These don't work:

  1. Async functions decorated with @pipe_output/@pipe_input/@mutate (have not tested all configurations of these)
  2. Async transformation functions wrapped in step(...)

Describe the solution you'd like

These should work. See how other decorators handle async functions, e.g. https://github.com/DAGWorks-Inc/hamilton/blob/9ae4183e8ceb98f3393c8d62923fb205cc370eb0/hamilton/function_modifiers/recursive.py#L399
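To make the request concrete, here is a minimal standard-library sketch of a decorator that applies post-processing steps to both sync and async functions. It is not Hamilton's implementation and does not reproduce the linked code: `apply_steps`, `_maybe_await`, and the example functions are hypothetical names invented for illustration. The assumption is that branching on `inspect.iscoroutinefunction` and awaiting where needed is one reasonable way to get this behavior.

```python
import asyncio
import functools
import inspect
from typing import Any, Callable


async def _maybe_await(value: Any) -> Any:
    """Await the value if it is awaitable, otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


def apply_steps(*steps: Callable[[Any], Any]) -> Callable:
    """Hypothetical stand-in for a pipe_output-style decorator (not Hamilton's API)."""

    def decorator(fn: Callable) -> Callable:
        if inspect.iscoroutinefunction(fn):
            # Async target: the wrapper must itself be a coroutine function.
            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                result = await fn(*args, **kwargs)
                for s in steps:
                    # Each step may be sync or async.
                    result = await _maybe_await(s(result))
                return result

            return async_wrapper

        # Sync target: apply the steps directly.
        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            result = fn(*args, **kwargs)
            for s in steps:
                result = s(result)
            return result

        return sync_wrapper

    return decorator


def _double(x: int) -> int:
    return x * 2


async def _add_one(x: int) -> int:
    await asyncio.sleep(0)
    return x + 1


@apply_steps(_double, _add_one)
async def fetch_value() -> int:
    await asyncio.sleep(0)
    return 10


@apply_steps(_double)
def plain_value() -> int:
    return 10
```

In this sketch `asyncio.run(fetch_value())` runs the async function and then both steps in order, while `plain_value()` stays an ordinary synchronous call. How Hamilton would wire this into its node expansion is left open here.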

Describe alternatives you've considered

In some cases a workaround is possible, e.g. an identity node. See the first comment for an example.

Additional context

https://hamilton-opensource.slack.com/archives/C03M33QB4M8/p1729156532663429

elijahbenizzy commented 4 hours ago

Workaround for @pipe_output

import asyncio

import pandas as pd

from hamilton import async_driver
from hamilton.function_modifiers import pipe_output, step, hamilton_exclude

async def data_input() -> pd.DataFrame:
    await asyncio.sleep(0.0001)
    return pd.DataFrame({
        "a": [1, 2, 3],
        "b": [4, 5, 6]
    })

def _groupby_a(d: pd.DataFrame) -> pd.DataFrame:
    return d.groupby("a").sum().reset_index()

def _groupby_b(d: pd.DataFrame) -> pd.DataFrame:
    return d.groupby("b").sum().reset_index()

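# Sync identity node: @pipe_output is attached here rather than to the async data_input.
# With groupby="a" in the config, _groupby_a is applied; otherwise _groupby_b is.
@pipe_output(
    step(_groupby_a).when(groupby="a"),
    step(_groupby_b).when_not(groupby="a"),
)
def data(data_input: pd.DataFrame) -> pd.DataFrame:
    return data_input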

@hamilton_exclude  # keep main() itself out of the DAG
async def main():
    import __main__  # build the DAG from the functions defined in this script
    dr = await (
        async_driver.Builder()
        .with_modules(__main__)
        .with_config(dict(groupby="b"))
        .build()
    )
    results = await dr.execute(["data"])
    print(results)

if __name__ == "__main__":
    asyncio.run(main())