pipecat-ai / pipecat

Open Source framework for voice and multimodal conversational AI
BSD 2-Clause "Simplified" License

[Feature Request] Function Calling Integration with Google Gemini #265

Open gaceladri opened 3 days ago

gaceladri commented 3 days ago

Description: To enhance the capabilities and interactivity of our applications, it would be highly beneficial to integrate function calling with the Google Gemini models. This feature would let the bot invoke predefined functions based on user input and context, allowing for more dynamic and useful conversations. According to the Gorilla function-calling leaderboard, Gemini's function calling performs considerably better than GPT-3.5, and at a lower cost than GPT-4.
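
For reference, the kind of tool I have in mind is just a plain Python callable with type hints and a docstring; as far as I understand, google-generativeai can build the function declaration from the callable automatically when it is passed via tools=. The function below is only an illustration (name and return values are made up):

import google.generativeai as genai

def get_weather(city: str) -> dict:
    """Return the current weather for a city (illustrative stub)."""
    return {"city": city, "conditions": "sunny", "temperature_c": 22}

# The SDK derives the function declaration from the signature and docstring.
model = genai.GenerativeModel("gemini-1.5-flash-latest", tools=[get_weather])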

I attempted this integration myself but was unsuccessful:

#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import json
from typing import Callable, List, Optional

from loguru import logger

from pipecat.frames.frames import (
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    LLMResponseEndFrame,
    LLMResponseStartFrame,
    TextFrame,
    VisionImageRawFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService

try:
    import google.ai.generativelanguage as glm
    import google.generativeai as gai
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_API_KEY` environment variable."
    )
    raise Exception(f"Missing module: {e}")

class GoogleLLMService(LLMService):
    """This class implements inference with Google's AI models

    This service translates internally from OpenAILLMContext to the messages format
    expected by the Google AI model. We are using the OpenAILLMContext as a lingua
    franca for all LLM services, so that it is easy to switch between different LLMs.
    """

    def __init__(self, api_key: str, model: str = "gemini-1.5-flash-latest", tools: Optional[List[Callable]] = None, **kwargs):
        super().__init__(**kwargs)
        gai.configure(api_key=api_key)
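        # tools can be plain Python callables (the SDK builds function
        # declarations from their signatures) or glm.Tool objects.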
        self._tools = tools or []
        self._client = gai.GenerativeModel(model, tools=self._tools)

    def can_generate_metrics(self) -> bool:
        return True

    def _get_messages_from_openai_context(self, context: OpenAILLMContext) -> List[glm.Content]:
        openai_messages = context.get_messages()
        google_messages = []

        for message in openai_messages:
            role = message["role"]
            content = message["content"]
            if role == "system":
                role = "user"
            elif role == "assistant":
                role = "model"

            parts = [glm.Part(text=content)]
            if "mime_type" in message:
                parts.append(
                    glm.Part(inline_data=glm.Blob(mime_type=message["mime_type"], data=message["data"].getvalue()))
                )
            google_messages.append({"role": role, "parts": parts})

        return google_messages

    async def _async_generator_wrapper(self, sync_generator):
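        """generate_content(stream=True) returns a synchronous iterator; yield
        control back to the event loop between items so other frames keep
        flowing while the response streams."""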
        for item in sync_generator:
            yield item
            await asyncio.sleep(0)

    async def _process_context(self, context: OpenAILLMContext):
        await self.push_frame(LLMFullResponseStartFrame())
        try:
            logger.debug(f"Generating chat: {context.get_messages_json()}")

            messages = self._get_messages_from_openai_context(context)

            await self.start_ttfb_metrics()

            # The tools were already attached to the GenerativeModel in __init__;
            # the OpenAI-formatted tools on the context are not valid here.
            response = self._client.generate_content(
                messages, generation_config=gai.GenerationConfig(temperature=0), stream=True
            )

            await self.stop_ttfb_metrics()

            async for chunk in self._async_generator_wrapper(response):
                try:
                    for candidate in chunk.candidates:
                        if not (candidate.content and candidate.content.parts):
                            continue
                        for part in candidate.content.parts:
                            if part.text:
                                await self.push_frame(LLMResponseStartFrame())
                                await self.push_frame(TextFrame(part.text))
                                await self.push_frame(LLMResponseEndFrame())
                            elif part.function_call.name:
                                # Function calls arrive as parts of the candidate's
                                # content, not as a candidate-level attribute.
                                function_name = part.function_call.name
                                arguments = dict(part.function_call.args)
                                result = await self.call_function(function_name, arguments)

                                # OpenAILLMContext stores plain dict messages, so record
                                # the call and its result as messages rather than as a
                                # glm.FunctionResponse Content object.
                                context.add_message(
                                    {"role": "assistant", "content": f"Function call: {function_name}"}
                                )
                                context.add_message(
                                    {"role": "function", "name": function_name, "content": json.dumps(result)}
                                )

                                # Re-process the context so the model can use the function result.
                                await self._process_context(context)
                                return

                except Exception as e:
                    # Google LLMs seem to flag safety issues a lot!
                    if chunk.candidates[0].finish_reason == 3:
                        logger.debug(f"LLM refused to generate content for safety reasons - {messages}.")
                    else:
                        logger.error(f"{self} error: {e}")

        except Exception as e:
            logger.error(f"{self} exception: {e}")
        finally:
            await self.push_frame(LLMFullResponseEndFrame())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None

        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)
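
For completeness, this is roughly how I expected to wire it up. register_function is assumed to behave like it does for the OpenAI service, and get_weather is the hypothetical tool from the example above:

import os

llm = GoogleLLMService(
    api_key=os.getenv("GOOGLE_API_KEY"),
    model="gemini-1.5-flash-latest",
    tools=[get_weather],  # hypothetical callable from the example above
)

# Assumed handler registration; the handler signature is kept generic because
# I'm not sure what the base LLMService passes to it.
async def handle_get_weather(*args, **kwargs):
    return {"conditions": "sunny", "temperature_c": 22}

llm.register_function("get_weather", handle_get_weather)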