keke-tracing / keke

Simple TraceEvent in Python
MIT License
3 stars 2 forks source link

ktrace doesn't work well with generators #8

Open zsol opened 4 months ago

zsol commented 4 months ago

when a generator function is decorated with @ktrace, the trace duration ends when the generator object is constructed. Example:

@ktrace()
def foo() -> Generator[None, None, None]:
    with kev("kev.foo"):
        yield None

Exhausting the above generator will produce two events that don't overlap with each other. I think solving this would be as simple as inspecting the return value of the decorated function and then

if inspect.isgenerator(return_value):
  yield from return_value
else:
  return return_value

The issue unfortunately also applies to async generators, which aren't as simple to solve AFAICT

zsol commented 4 months ago

This seems to work, wdyt?

diff --git a/keke/__init__.py b/keke/__init__.py
index 51d70ee..6440402 100644
--- a/keke/__init__.py
+++ b/keke/__init__.py
@@ -18,7 +18,7 @@ import threading
 import time
 from contextlib import contextmanager
 from functools import wraps
-from inspect import signature
+from inspect import isasyncgenfunction, isgeneratorfunction, signature
 from queue import SimpleQueue
 from typing import (
     Any,
@@ -268,8 +268,7 @@ def ktrace(*trace_args: str, shortname: Union[str, bool] = False) -> Callable[[F
         else:
             name = func.__qualname__

-        @wraps(func)
-        def dec(*args: Any, **kwargs: Any) -> Any:
+        def _get_params(*args: Any, **kwargs: Any) -> dict[str, str]:
             t = get_tracer()
             if t is None:
                 return func(*args, **kwargs)
@@ -283,9 +282,24 @@ def ktrace(*trace_args: str, shortname: Union[str, bool] = False) -> Callable[[F
                 except Exception as e:
                     return repr(e)

-            params = {k: safe_get(k) for k in trace_args}
-            with kev(name, **params):
-                return func(*args, **kwargs)
+            return {k: safe_get(k) for k in trace_args}
+
+        if isasyncgenfunction(func):
+            @wraps(func)
+            async def dec(*args: Any, **kwargs: Any) -> Any:
+                with kev(name, **_get_params(*args, **kwargs)):
+                    async for item in func(*args, **kwargs):
+                        yield item
+        elif isgeneratorfunction(func):
+            @wraps(func)
+            def dec(*args: Any, **kwargs: Any) -> Any:
+                with kev(name, **_get_params(*args, **kwargs)):
+                    yield from func(*args, **kwargs)
+        else:
+            @wraps(func)
+            def dec(*args: Any, **kwargs: Any) -> Any:
+                with kev(name, **_get_params(*args, **kwargs)):
+                    return func(*args, **kwargs)

         return cast(F, dec)