cfahlgren1 / observers

A Lightweight Library for AI Observability

Refactoring proposal for better multi-client support #23

Open hanouticelina opened 2 days ago

hanouticelina commented 2 days ago

Hi @cfahlgren1 and @davidberenstein1957, really nice work, I like the idea! While going through the code, I had a refactoring idea to make it easier to add support for other text generation/chat completion clients (#15 and #6), and I thought I might drop it here! 🤗

Currently, we modify the client's methods directly (see observers/models/openai.py#L210), which might not work for all clients, as some might not support this kind of method replacement. Building the library as a wrapper around the client SDK, rather than a patch over it, would also make it more maintainable and less prone to compatibility issues.
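
For context, the kind of method replacement meant here boils down to something like this (a simplified illustration of the general monkey-patching pattern, not the actual observers code):

# simplified illustration of the current monkey-patching style,
# not the actual observers code
original_create = client.chat.completions.create

def tracked_create(*args, **kwargs):
    response = original_create(*args, **kwargs)
    store.add(parse_response(response))  # record the call before returning
    return response

# this attribute assignment is exactly what some clients forbid,
# e.g. SDKs whose resource objects define __slots__ or are frozen models
client.chat.completions.create = tracked_create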

So here is what I have in mind: instead of modifying the client directly, we create a simple wrapper that works with any client:

In observers/models/base.py:

...
# imports here
@dataclass
class ChatCompletionRecord(Record):
    """
    Data class for storing chat completion records.
    """
    # same attributes as OpenAIResponseRecord

    @property
    def duckdb_schema(self):
        # same duckdb_schema as in OpenAIResponseRecord
        ...

    def argilla_settings(self, client: "Argilla"):
        # same argilla settings as in OpenAIResponseRecord
        ...

    @property
    def json_fields(self):
        return ["tool_calls", "function_call", "tags", "properties", "raw_response"]

class ChatCompletionObserver:
   """
    Observer that provides a clean interface for tracking chat completions
    Args:
        client (Any):
            The client to use for the chat completions.
        create (Callable[..., Any]):
            The function to use to create the chat completions., eg `chat.completions.create` for OpenAI client.
        format_input (Callable[[Dict[str, Any], Any], Any]):
            The function to use to format the input messages.
        parse_response (Callable[[Any], Dict[str, Any]]):
            The function to use to parse the response.
        store (Optional[Union["DuckDBStore", DatasetsStore]]):
            The store to use to save the records.
    """

    def __init__(
        self,
        client: Any,
        create: Callable[..., Any],
        format_input: Callable[[Dict[str, Any], Any], Any],
        parse_response: Callable[[Any], Dict[str, Any]],
        store: Optional[Union["DuckDBStore", DatasetsStore]] = None,
        **kwargs: Any,
    ):
        self.client = client
        self.create_fn = create
        self.format_input = format_input
        self.parse_response = parse_response
        self.store = store or DatasetsStore.connect()
        self.kwargs = kwargs

    @property
    def chat(self) -> Self:
        return self

    @property
    def completions(self) -> Self:
        return self

    def create(
        self,
        inputs: Dict[str, Any],
        tags: Optional[List[str]] = None,
        properties: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Create a completion and store the response in a database"""
        tags = tags or []
        properties = properties or {}
        response = None
        try:
            kwargs = self.handle_kwargs(kwargs)

            input_data = self.format_input(inputs, **kwargs)
            response = self.create_fn(**input_data)
            record = self.parse_response(
                response,
                tags=tags,
                properties=properties,
            )

            self.store.add(record)
            return response
        except Exception as e:
            record = self.parse_response(
                response,
                error=e,
                model=kwargs.get("model"),
            )
            self.store.add(record)
            raise

    def handle_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Handle and process keyword arguments for the API call.
        It ensures that any kwargs passed to the method call take precedence over the default ones.
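
        Example:
            observer = ChatCompletionObserver(..., model="gpt-4o-mini")
            # the per-call value wins: the API call sees model="gpt-4o"
            observer.create(inputs, model="gpt-4o")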
        """
        return {**self.kwargs, **kwargs}

    def __getattr__(self, attr: str) -> Any:
        # __getattr__ is only called when normal lookup fails, so the
        # observer's own attributes (create, chat, completions, ...) are
        # resolved first. `messages` is answered with the observer itself so
        # that Anthropic-style `client.messages.create(...)` also routes
        # through `create`; everything else is delegated to the wrapped client.
        if attr == "messages":
            return self
        return getattr(self.client, attr)

class AsyncChatCompletionObserver(ChatCompletionObserver):
    def __init__(
        self,
        client: Any,
        create: Callable[..., Any],
        format_input: Callable[[Dict[str, Any], Any], Any],
        parse_response: Callable[[Any], Dict[str, Any]],
        store: Optional[Union["DuckDBStore", DatasetsStore]] = None,
        **kwargs: Any,
    ):
        super().__init__(client, create, format_input, parse_response, store, **kwargs)
        raise NotImplementedError("Async support not implemented yet")

Here is how it looks with transformers.pipeline:

In observers/models/transformers.py:

...

class TransformersRecord(ChatCompletionRecord):
    @classmethod
    def from_response(
        cls,
        response: Optional[List[Dict[str, Any]]] = None,
        error: Optional[Exception] = None,
        **kwargs,
    ) -> Self:
        if not response:
            return cls(finish_reason="error", error=str(error), **kwargs)
        generated_text = response[0]["generated_text"][-1]
        return cls(
            id=str(uuid.uuid4()),
            assistant_message=generated_text.get("content"),
            tool_calls=generated_text.get("tool_calls"),
            raw_response=response,
            **kwargs,
        )

def wrap_transformers(
    client: transformers.TextGenerationPipeline,
    store: Optional[Union[DuckDBStore, DatasetsStore]] = None,
) -> ChatCompletionObserver:
    return ChatCompletionObserver(
        client=client,
        create=client.__call__,
        format_input=lambda inputs, **kwargs: {"text_inputs": inputs, **kwargs},
        parse_response=TransformersRecord.from_response,
        store=store,
    )
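
Usage would then look something like this (a hypothetical sketch; the model name is just an example):

# hypothetical usage sketch for wrap_transformers above
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")
observer = wrap_transformers(pipe)

response = observer.create(
    inputs=[{"role": "user", "content": "Tell me a joke."}],
    tags=["demo"],
)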

In observers/models/openai.py:

...
class OpenAIRecord(ChatCompletionRecord):
    @classmethod
    def from_response(
        cls,
        response: Optional[openai.types.chat.ChatCompletion] = None,
        error: Optional[Exception] = None,
        **kwargs,
    ) -> Self:
        # same as in OpenAIResponseRecord
        ...

# personal opinion: I'd rather call these functions `from_{client_name}()`,
# but let's keep `wrap_{client_name}()` for backward compatibility
def wrap_openai(
    client: openai.OpenAI | openai.AsyncOpenAI,
    store: Optional[Union["DuckDBStore", DatasetsStore]] = None,
) -> ChatCompletionObserver:
    if isinstance(client, openai.AsyncOpenAI):
        return AsyncChatCompletionObserver(
            client=client,
            create=client.chat.completions.create,
            format_input=lambda inputs, **kwargs: {"messages": inputs, **kwargs},
            parse_response=OpenAIRecord.from_response,
            store=store,
        )

    return ChatCompletionObserver(
        client=client,
        create=client.chat.completions.create,
        format_input=lambda inputs, **kwargs: {"messages": inputs, **kwargs},
        parse_response=OpenAIRecord.from_response,
        store=store,
    )
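
Since `chat` and `completions` both return the observer itself, the wrapped client keeps the familiar call shape (a hypothetical usage sketch; note that the messages are passed as `inputs`):

# hypothetical usage sketch for wrap_openai above
import openai

client = wrap_openai(openai.OpenAI())

response = client.chat.completions.create(
    inputs=[{"role": "user", "content": "Hello!"}],
    model="gpt-4o-mini",
    tags=["greeting"],
)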

The same pattern works for huggingface_hub.InferenceClient!
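
For instance, a hypothetical sketch (assuming an `HFRecord` class analogous to `OpenAIRecord`):

# hypothetical sketch; HFRecord is assumed, not defined above
import huggingface_hub

def wrap_hf_client(
    client: huggingface_hub.InferenceClient,
    store: Optional[Union["DuckDBStore", DatasetsStore]] = None,
) -> ChatCompletionObserver:
    # InferenceClient exposes an OpenAI-compatible chat.completions.create,
    # so the wiring matches wrap_openai almost exactly
    return ChatCompletionObserver(
        client=client,
        create=client.chat.completions.create,
        format_input=lambda inputs, **kwargs: {"messages": inputs, **kwargs},
        parse_response=HFRecord.from_response,  # hypothetical record class
        store=store,
    )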

WDYT? Of course, there are still some things that can be improved, but this is basically the initial idea. If that makes sense to you, I'd be happy to open a PR for it (with tests, obviously!)

cfahlgren1 commented 2 days ago

This tracks for me! Thanks for the really well-detailed proposal! A PR would be awesome 🤗

davidberenstein1957 commented 2 days ago

+1, looks great and helps a lot with the structure of the wrapping, which we hadn't worked on much yet 🥇

hanouticelina commented 10 hours ago

Great! I will open a PR this week!