DAGWorks-Inc / burr

Build applications that make decisions (chatbots, agents, simulations, etc.). Monitor, trace, persist, and execute on your own infrastructure.
https://burr.dagworks.io
BSD 3-Clause Clear License

Typing State for the class-based API does not work #400

Open · mdrideout opened this issue 1 month ago

Following the docs for action-level typing does not work for class-based actions.

Ref: https://github.com/DAGWorks-Inc/burr/issues/386

Current behavior

Example first action:

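import logging

from burr.core import Action

# ApplicationState is the reporter's pydantic state model; its definition is not included in the issue.
logger = logging.getLogger(__name__)


class SetInitialPromptAction(Action):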
    @property
    def reads(self) -> list[str]:
        return []

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    @property
    def writes(self) -> list[str]:
        return ["initial_prompt"]

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        prompt = result["initial_prompt"]
        logger.info(f"Saving prompt to state: {prompt}")
        state.initial_prompt = prompt
        return state

    @property
    def inputs(self) -> list[str]:
        return ["prompt"]

Example second action:

class ExtractSetAction(Action):
    @property
    def reads(self) -> list[str]:
        return ["initial_prompt"]

    def run(self, state: ApplicationState) -> dict:
        logger.info(f"ApplicationState: {state}")

        # Read prompt from state
        prompt = state.initial_prompt
        ...

Logs: ApplicationState: {'initial_prompt': None}
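The log line above and the AttributeError in the stack trace are consistent with the class-based actions receiving a mapping-style state object rather than the typed model: the attribute assignment in the first action's update() never reaches the stored values, and attribute reads fail on the state handed to the next action. The stdlib-only sketch below reproduces that pattern with a hypothetical MiniState class that only imitates the key-based reads and copy-on-update writes of Burr's State (state["key"], state.update(...)); it is an illustration under that assumption, not Burr's implementation.

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping
from typing import Any


class MiniState(Mapping):
    """Hypothetical stand-in for an immutable, mapping-style state object.

    Values live in a private dict, are read with state["key"], and are
    "written" by returning a new object from update(). Illustration only;
    this is not Burr's State class.
    """

    def __init__(self, values: dict[str, Any] | None = None) -> None:
        self._values = dict(values or {})

    def __getitem__(self, key: str) -> Any:
        return self._values[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._values)

    def __len__(self) -> int:
        return len(self._values)

    def update(self, **updates: Any) -> MiniState:
        # Copy-on-write: the original object is left untouched.
        return MiniState({**self._values, **updates})


state = MiniState({"initial_prompt": None})

# What the first action's update() does: attribute assignment. This only adds
# a Python attribute to this one object; the stored values are unchanged.
state.initial_prompt = "bicep curls"
print(state["initial_prompt"])  # None

# A fresh object built from the stored values (like the state handed to the
# next action) has no such attribute, so attribute-style reads fail.
next_state = MiniState(dict(state))
try:
    _ = next_state.initial_prompt
except AttributeError as exc:
    print(exc)  # 'MiniState' object has no attribute 'initial_prompt'

# Key-based writes and reads go through the stored values instead.
fixed = state.update(initial_prompt="bicep curls")
print(fixed["initial_prompt"])  # bicep curls
```

With Burr's untyped API, the analogous workaround would be reading state["initial_prompt"] in run() and returning state.update(initial_prompt=...) from update(); whether that combines cleanly with the pydantic typing setup used here is an assumption worth checking against the docs.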

Stack Traces

********************************************************************************
-------------------------------------------------------------------
Oh no an error! Need help with Burr?
Join our discord and ask for help! https://discord.gg/4FxBMyzW5n
-------------------------------------------------------------------
> Action: `extract_set` encountered an error!<
> State (at time of action):
{'__PRIOR_STEP': 'set_prompt',
 '__SEQUENCE_ID': 1,
 'initial_prompt': None,
 'set_from_prompt': None}
> Inputs (at time of action):
{'prompt': 'bicep curls with 22 pound dumbells for 21 reps'}
********************************************************************************
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
    result = _run_function(
             ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
    result = function.run(state_to_use, **inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/app/ai/actions/extract_set/action.py", line 93, in run
    prompt = state.initial_prompt
             ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'State' object has no attribute 'initial_prompt'
INFO: 192.168.65.1:35000 - "GET /api/extract-set?prompt=bicep%20curls%20with%2022%20pound%20dumbells%20for%2021%20reps HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
  + Exception Group Traceback (most recent call last):
  |   File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 77, in collapse_excgroups
  |     yield
  |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 186, in __call__
  |     async with anyio.create_task_group() as task_group:
  |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 736, in __aexit__
  |     raise BaseExceptionGroup(
  | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
    |     result = await app(  # type: ignore[func-returns-value]
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    |     return await self.app(scope, receive, send)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
    |     await super().__call__(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in __call__
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in __call__
    |     await self.app(scope, receive, _send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in __call__
    |     with collapse_excgroups():
    |   File "/usr/local/lib/python3.11/contextlib.py", line 158, in __exit__
    |     self.gen.throw(typ, value, traceback)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in __call__
    |     response = await self.dispatch_func(request, call_next)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/main.py", line 36, in log_requests
    |     response = await call_next(request)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
    |     raise app_exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
    |     await self.app(scope, receive_or_disconnect, send_no_error)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    |     await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
    |     await route.handle(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
    |     await self.app(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
    |     await wrap_app_handling_exceptions(app, request)(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
    |     response = await f(request)
    |                ^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
    |     raw_response = await run_endpoint_function(
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
    |     return await run_in_threadpool(dependant.call, **values)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
    |     return await anyio.to_thread.run_sync(func, *args)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    |     return await get_async_backend().run_sync_in_worker_thread(
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
    |     return await future
    |            ^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
    |     result = context.run(func, *args)
    |              ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/api/routes.py", line 53, in extract_set
    |     action, result, state = application.run(
    |                             ^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
    |     return call_fn(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    |     return fn(app_self, *args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
    |     next(gen)
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
    |     prior_action, result, state = self.step(inputs=inputs)
    |                                   ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    |     return fn(app_self, *args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
    |     out = self._step(inputs=inputs, _run_hooks=True)
    |           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
    |     raise e
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
    |     result = _run_function(
    |              ^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
    |     result = function.run(state_to_use, **inputs)
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/ai/actions/extract_set/action.py", line 93, in run
    |     prompt = state.initial_prompt
    |              ^^^^^^^^^^^^^^^^^^^^
    | AttributeError: 'State' object has no attribute 'initial_prompt'
    +------------------------------------

Library & System Information

Python 3.11 (per the paths in the traceback), served with FastAPI and uvicorn. Burr dependency as declared:

burr = { extras = [
  "graphviz",
  "hamilton",
  "streamlit",
  "tracking-client",
  "tracking-server",
], version = "^0.31.1" }
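To collect the exact interpreter and package versions this section asks for, a small stdlib-only helper can query the installed distributions. This is a sketch; the helper name environment_report and the package list are illustrative choices, not part of Burr.

```python
import sys
from importlib import metadata


def environment_report(packages: list[str]) -> dict[str, str]:
    """Return the Python version plus the installed version of each package."""
    report = {"python": ".".join(str(part) for part in sys.version_info[:3])}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Distribution is not installed in this environment.
            report[name] = "not installed"
    return report


print(environment_report(["burr", "fastapi", "uvicorn", "starlette", "anyio"]))
```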

Expected behavior

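Class-based actions should behave the same as function-based actions: run() and update() should receive the typed ApplicationState model rather than the untyped State object, so that the prompt written by SetInitialPromptAction can be read as state.initial_prompt in ExtractSetAction.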
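Behaving "the same as function-based actions" can be read as a round trip: update() gets a typed model, and whatever it assigns to its declared writes is copied back into the state. The stdlib-only sketch below illustrates that round trip under stated assumptions. It is not Burr's pydantic typing system: a dataclass stands in for the reporter's pydantic ApplicationState, and to_model, run_typed_update and set_prompt_update are hypothetical helpers.

```python
from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, Optional


@dataclass
class ApplicationState:
    # Hypothetical typed model; fields mirror the state keys seen in the logs.
    initial_prompt: Optional[str] = None
    set_from_prompt: Optional[str] = None


def to_model(values: dict[str, Any]) -> ApplicationState:
    """Build the typed model from a plain state dict, skipping internal keys
    (such as '__SEQUENCE_ID') that are not model fields."""
    names = {f.name for f in fields(ApplicationState)}
    return ApplicationState(**{k: v for k, v in values.items() if k in names})


def run_typed_update(
    values: dict[str, Any],
    writes: list[str],
    update_fn: Callable[[ApplicationState], ApplicationState],
) -> dict[str, Any]:
    """Hand update_fn a typed model, then copy only the declared writes back
    into a new plain dict; the input dict is not modified."""
    written = asdict(update_fn(to_model(values)))
    return {**values, **{key: written[key] for key in writes}}


def set_prompt_update(state: ApplicationState) -> ApplicationState:
    # Attribute-style write, as in the first action's update().
    state.initial_prompt = "bicep curls"
    return state


before = {"__SEQUENCE_ID": 0, "initial_prompt": None, "set_from_prompt": None}
after = run_typed_update(before, writes=["initial_prompt"], update_fn=set_prompt_update)
print(after["initial_prompt"])  # bicep curls
```

Under this model the attribute assignment in the first action's update() persists, which matches what the reporter expected from the class-based API.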
