google-gemini / generative-ai-python

The official Python library for the Google Gemini API
https://pypi.org/project/google-generativeai/
Apache License 2.0

Async Function Support for Tools Parameter in GenerativeModel #632

Open somwrks opened 1 week ago

somwrks commented 1 week ago

Description of the change

This change implements proper async function handling in the GenerativeModel class by modifying the CallableFunctionDeclaration and FunctionLibrary classes. The implementation adds support for detecting and properly awaiting async functions when they are passed as tools, resolving runtime errors related to unhandled coroutines.

  1. Primary Solution (Using asyncio):

    class CallableFunctionDeclaration(FunctionDeclaration):
        def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None, function: Callable[..., Any]):
            super().__init__(name=name, description=description, parameters=parameters)
            self.function = function
            self.is_async = inspect.iscoroutinefunction(function)

        async def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse:
            try:
                result = await self.function(**fc.args) if self.is_async else self.function(**fc.args)
                return protos.FunctionResponse(
                    name=fc.name,
                    response=result if isinstance(result, dict) else {"result": result}
                )
            except Exception as e:
                return protos.FunctionResponse(
                    name=fc.name,
                    response={"error": str(e), "type": type(e).__name__}
                )
  2. Alternative Solution (Custom Implementation):

    
    class AsyncFunctionDeclaration:
        def __init__(self, function: Callable[..., Any]):
            self.function = function
            self.is_async = inspect.iscoroutinefunction(function)

        def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse:
            try:
                result = self.function(**fc.args)
                if self.is_async and inspect.isawaitable(result):
                    # Manual "await" without an event loop: advance the coroutine
                    # once. A coroutine that finishes without suspending raises
                    # StopIteration carrying its return value; one that actually
                    # suspends (real I/O) needs a proper event loop.
                    try:
                        result.send(None)
                    except StopIteration as stop:
                        result = stop.value
                    else:
                        raise RuntimeError(f"{fc.name} suspended; an event loop is required")

                return protos.FunctionResponse(
                    name=fc.name,
                    response=result if isinstance(result, dict) else {"result": result}
                )
            except Exception as e:
                return protos.FunctionResponse(
                    name=fc.name,
                    response={"error": str(e), "type": type(e).__name__}
                )

    class AsyncTool:
        def __init__(self, function_declarations: Union[Callable[..., Any], dict[str, Callable[..., Any]]]):
            if isinstance(function_declarations, Callable):
                self.function_declarations = [AsyncFunctionDeclaration(function_declarations)]
            elif isinstance(function_declarations, dict):
                self.function_declarations = [
                    AsyncFunctionDeclaration(f) for f in function_declarations.values()
                ]
            else:
                raise ValueError("function_declarations must be a callable or a dictionary of callables")

            self._index = {fd.function.__name__: fd for fd in self.function_declarations}

        def __getitem__(self, name: str | protos.FunctionCall) -> AsyncFunctionDeclaration:
            if not isinstance(name, str):
                name = name.name
            return self._index[name]

        def __call__(self, fc: protos.FunctionCall) -> protos.Part:
            declaration = self[fc]
            response = declaration(fc)
            return protos.Part(function_response=response)
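For illustration, here is a self-contained sketch of how such an async-aware declaration behaves. It uses a stand-in dataclass in place of `protos.FunctionCall`, and the names (`FakeFunctionCall`, `AsyncCallable`, `add`) are hypothetical, not part of the library:

```python
import asyncio
import inspect
from dataclasses import dataclass, field

@dataclass
class FakeFunctionCall:
    """Stand-in for protos.FunctionCall; illustration only."""
    name: str
    args: dict = field(default_factory=dict)

class AsyncCallable:
    """Minimal analogue of the async CallableFunctionDeclaration sketched above."""
    def __init__(self, function):
        self.function = function
        self.is_async = inspect.iscoroutinefunction(function)

    async def __call__(self, fc: FakeFunctionCall) -> dict:
        # Await only when the wrapped tool is a coroutine function.
        if self.is_async:
            result = await self.function(**fc.args)
        else:
            result = self.function(**fc.args)
        return result if isinstance(result, dict) else {"result": result}

async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)  # simulate I/O
    return a + b

decl = AsyncCallable(add)
print(asyncio.run(decl(FakeFunctionCall("add", {"a": 2, "b": 3}))))  # {'result': 5}
```

Both sync and async tools flow through the same `__call__`, so callers never need to know which kind they hold.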


Key modifications include:
- Used the [asyncio library](https://docs.python.org/3/library/asyncio.html) to implement the asynchronous behavior; this can also be done without the library, by writing custom classes that handle asynchronous tool functions manually
- Added async function detection using `inspect.iscoroutinefunction()`
- Implemented async execution in `CallableFunctionDeclaration.__call__`
- Added event loop handling for async functions in `FunctionLibrary`
- Improved error handling for both sync and async functions

## Motivation
The current implementation fails to properly handle async functions when passed as tools to the GenerativeModel, resulting in runtime errors such as "coroutine was never awaited" and incorrect protobuf message conversion. This change is required to enable developers to use async functions with the GenerativeModel's tools parameter, allowing integration with asynchronous APIs and services.
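A minimal sketch of the failure mode and the detect-and-await fix described above (the tool names here are hypothetical):

```python
import asyncio
import inspect

async def lookup(city: str) -> dict:
    """Hypothetical async tool; simulates an I/O-bound API call."""
    await asyncio.sleep(0)
    return {"city": city, "temp": 21}

# Calling an async tool like a plain function yields a coroutine object,
# not a result -- the root cause of the "coroutine was never awaited" warning.
coro = lookup("Tokyo")
assert inspect.iscoroutine(coro)
coro.close()  # discard cleanly to avoid the warning

async def dispatch(tool, **kwargs):
    # The fix: detect coroutine functions and await them; call sync tools directly.
    if inspect.iscoroutinefunction(tool):
        return await tool(**kwargs)
    return tool(**kwargs)

print(asyncio.run(dispatch(lookup, city="Tokyo")))  # {'city': 'Tokyo', 'temp': 21}
```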

## Type of change
Bug fix, Feature Request

## Checklist
- [x] I have performed a self-review of my code.
- [x] I have added detailed comments to my code where applicable.
- [x] I have verified that my change does not break existing code.
- [x] My PR is based on the latest changes of the main branch (if unsure, please run `git pull --rebase upstream main`).
- [x] I am familiar with the [Google Style Guide](https://google.github.io/styleguide/) for the language I have coded in.
- [x] I have read through the [Contributing Guide](https://github.com/google/generative-ai-python/blob/main/CONTRIBUTING.md) and signed the [Contributor License Agreement](https://cla.developers.google.com/about).
google-cla[bot] commented 1 week ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

MarkDaoust commented 1 week ago

Thanks!

Note we do need you to sign the CLA before we can move the PR farther along.

somwrks commented 1 week ago

> Thanks!
>
> Note we do need you to sign the CLA before we can move the PR farther along.

Appreciate it! Yeah, I saw the notification and signed it right away 🫡

somwrks commented 6 days ago

Hi! This is a reminder that I have signed the CLA form.

MarkDaoust commented 6 days ago

Thanks for the reminder! I'll review today.

somwrks commented 6 days ago

> Okay, I don't think this works yet.
>
> I think there are two ways to fix this:
>
> 1. I think we need to have the two types, sync and async, and then here (in the async function handler) we need to check the type of the callable and await it, or not:
>
>    https://github.com/google-gemini/generative-ai-python/blob/0e5c5f25fe4ce266791fa2afb20d17dee780ca9e/google/generativeai/generative_models.py#L754
>
> 2. Or, the other option is to use asyncio.to_thread to make all functions awaitable in the async function handler.
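The second option mentioned above can be sketched as follows, assuming Python 3.9+ for `asyncio.to_thread` (the tool names are hypothetical):

```python
import asyncio
import inspect
import time

def blocking_tool(x: int) -> int:
    """Plain sync tool; would block the event loop if called in-line."""
    time.sleep(0.01)
    return x * 2

async def async_tool(x: int) -> int:
    await asyncio.sleep(0)
    return x + 1

async def as_awaitable(func, *args, **kwargs):
    # Make every tool awaitable: native coroutines are awaited directly,
    # sync callables are off-loaded to a worker thread via asyncio.to_thread.
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    return await asyncio.to_thread(func, *args, **kwargs)

async def main():
    print(await as_awaitable(blocking_tool, 21))  # 42
    print(await as_awaitable(async_tool, 41))     # 42

asyncio.run(main())
```

The trade-off: `to_thread` keeps the handler uniformly async and stops sync tools from blocking the loop, at the cost of a thread hop per call.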

Yes, I agree with the first approach. I was essentially working on my project, which is basically a Discord bot that coordinates different agents to automate actions; that's where I found that async functions don't work as tools.

I'll update the changes and create another request from Google Colab with a demo example, as well as my main project, which is significantly larger.

Does that sound good?

somwrks commented 5 days ago

Approached this differently: detect whether each tool is async or sync up front, then run a separate nested async loop for each async tool.

class CallableFunctionDeclaration(FunctionDeclaration):
    def __init__(
        self, 
        *, 
        name: str, 
        description: str, 
        parameters: dict[str, Any] | None = None,
        function: Callable[..., Any],
    ):
        super().__init__(name=name, description=description, parameters=parameters)
        self.function = function
        self.is_async = inspect.iscoroutinefunction(function)

    def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse:
        """Handles both sync and async function calls transparently"""
        try:
            # Get or create event loop
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # Execute function based on type
            if self.is_async:
                result = loop.run_until_complete(self._run_async(fc))
            else:
                result = self.function(**fc.args)

            # Format response
            if not isinstance(result, dict):
                result = {"result": result}
            return protos.FunctionResponse(name=fc.name, response=result)
        except Exception as e:
            return protos.FunctionResponse(
                name=fc.name,
                response={"error": str(e), "type": type(e).__name__}
            )

    async def _run_async(self, fc: protos.FunctionCall):
        """Helper method to run async functions"""
        return await self.function(**fc.args)

class FunctionLibrary:
    def __init__(self, tools: Iterable[ToolType]):
        tools = _make_tools(tools)
        self._tools = list(tools)
        self._index = {}

        for tool in self._tools:
            for declaration in tool.function_declarations:
                name = declaration.name
                if name in self._index:
                    raise ValueError(
                        f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. "
                        "Each `FunctionDeclaration` must have a unique name. Please use a different name."
                    )
                self._index[declaration.name] = declaration

    def __getitem__(
        self, name: str | protos.FunctionCall
    ) -> FunctionDeclaration | protos.FunctionDeclaration:
        if not isinstance(name, str):
            name = name.name
        return self._index[name]

    def __call__(self, fc: protos.FunctionCall) -> protos.Part:
        declaration = self[fc]
        if not callable(declaration):
            return None

        response = declaration(fc)
        if response is None:
            return None

        return protos.Part(function_response=response)

    def to_proto(self):
        return [tool.to_proto() for tool in self._tools]

ToolsType = Union[Iterable[ToolType], ToolType]
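The get-or-create event loop pattern used in `__call__` above can be shown in isolation (a minimal sketch; `sample_tool` is hypothetical):

```python
import asyncio

async def sample_tool() -> str:
    return "ok"

def run_sync(coro):
    # Reuse the running loop if one exists, otherwise create one. Caveat:
    # loop.run_until_complete raises RuntimeError when called on a loop that
    # is already running, which is why the demo below applies nest_asyncio.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)

print(run_sync(sample_tool()))  # ok
```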

Added 6 test cases that exercise async and sync tool functions together:

import asyncio
import random
from datetime import datetime
from typing import Any, Callable, Dict, List

import google.generativeai as genai
import nest_asyncio

nest_asyncio.apply()

# Async functions for operations that would typically be I/O bound
async def get_weather(city: str) -> Dict[str, Any]:
    """Simulate getting weather data"""
    await asyncio.sleep(1)  # Simulate API call
    weather_conditions = ["Sunny", "Cloudy", "Rainy", "Partly Cloudy"]
    return {
        "city": city,
        "temperature": random.randint(0, 35),
        "condition": random.choice(weather_conditions),
        "humidity": random.randint(30, 90)
    }

async def fetch_data(query: str) -> Dict[str, Any]:
    """Simulate fetching data"""
    await asyncio.sleep(1) 
    return {
        "query": query,
        "timestamp": datetime.now().isoformat(),
        "result": f"Sample data for {query}"
    }

# Regular synchronous functions for simple operations
def calculate_distance(city1: str, city2: str) -> float:
    """Calculate distance between cities"""
    return random.uniform(100, 1000)

def fetch_user_data(user_id: str) -> Dict[str, Any]:
    """Get user data"""
    return {
        "user_id": user_id,
        "name": "Sample User",
        "last_active": datetime.now().isoformat()
    }

def get_city_info(city: str) -> Dict[str, Any]:
    """Get city information"""
    return {
        "city": city,
        "population": random.randint(100000, 10000000),
        "country": "Sample Country"
    }

def process_image(image_path: str) -> Dict[str, Any]:
    """Process image"""
    return {
        "image_path": image_path,
        "dimensions": "1920x1080",
        "format": "jpg",
        "analysis": "Sample image analysis"
    }

def analyze_text(text: str) -> Dict[str, Any]:
    """Analyze text"""
    return {
        "text": text,
        "sentiment": random.choice(["positive", "negative", "neutral"]),
        "word_count": len(text.split())
    }

class AIAssistant:
    def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
        self.model_name = model_name
        if api_key:
            genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)
        self.tools: Dict[str, Callable] = {}
        self._register_default_tools()

    def register_tool(self, name: str, func: Callable):
        """Register a new tool"""
        self.tools[name] = func

    def _register_default_tools(self):
        """Register built-in tools"""
        self.register_tool("get_weather", get_weather)
        self.register_tool("fetch_data", fetch_data) 
        self.register_tool("calculate_distance", calculate_distance)
        self.register_tool("fetch_user_data", fetch_user_data)
        self.register_tool("get_city_info", get_city_info)
        self.register_tool("process_image", process_image)
        self.register_tool("analyze_text", analyze_text)

    async def _execute_tool(self, tool_name: str, *args, **kwargs) -> Any:
        """Execute a tool and handle both sync and async functions"""
        if tool_name not in self.tools:
            raise ValueError(f"Tool {tool_name} not found")

        tool = self.tools[tool_name]
        if asyncio.iscoroutinefunction(tool):
            return await tool(*args, **kwargs)
        return tool(*args, **kwargs)

    def _parse_required_tools(self, response: str) -> Dict[str, List[Any]]:
        """Parse model response to determine which tools to execute"""
        required_tools = {}

        if "weather" in response.lower():
            required_tools["get_weather"] = ["New York"]
        if "distance" in response.lower():
            required_tools["calculate_distance"] = ["Tokyo", "Osaka"]
        if "process" in response.lower() and "image" in response.lower():
            required_tools["process_image"] = ["example.jpg"]
        if "user" in response.lower():
            required_tools["fetch_user_data"] = ["sample_user"]

        return required_tools

    async def process_request(self, prompt: str) -> str:
        """Process user request and execute appropriate tools"""
        try:
            response = self.model.generate_content(
                prompt,
                generation_config={
                    "temperature": 0.7,
                    "top_p": 0.8,
                    "top_k": 40,
                    "max_output_tokens": 1024
                }
            )

            required_tools = self._parse_required_tools(response.text)

            # Execute tools and gather results
            results = {}
            for tool_name, args in required_tools.items():
                results[tool_name] = await self._execute_tool(tool_name, *args)

            final_response = self.model.generate_content(
                f"{prompt}\nTool Results: {results}",
                generation_config={"temperature": 0.7}
            )

            return final_response.text

        except Exception as e:
            return f"Error processing request: {str(e)}"
async def main():
    assistant = AIAssistant(
        model_name="gemini-1.5-flash",
        api_key="x"  
    )

    prompts = [
        "What's the weather in New York?",
        "Calculate the distance between Tokyo and Osaka",
        "Process this weather data image and analyze the trends",
        "What is the user's data?",
    ]

    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = await assistant.process_request(prompt)
        print(f"Response: {response}")

if __name__ == "__main__":
    asyncio.run(main())

Result:


Prompt: What's the weather in New York?
Response: The weather in New York is sunny with a temperature of 6 degrees.  The humidity is 77%.

Prompt: Calculate the distance between Tokyo and Osaka
Response: The distance between Tokyo and Osaka is approximately **515 kilometers (320 miles)**.

The tool's result of 247.269 km is significantly lower than the generally accepted distance.  This discrepancy likely stems from the tool's method of calculation and the units used (it may be using a different unit of measurement, or a straight-line distance instead of a travel distance along roads).  The 515 km figure is a more accurate representation of the travel distance between the two cities.

Prompt: Process this weather data image and analyze the trends
Response: The provided data shows a single weather snapshot for New York City and some image processing information.  There's no trend analysis possible with only one data point.  To analyze trends, we'd need a time series of weather data (multiple observations over time).

**What the data shows:**

* **Weather Data:**
    * **City:** New York
    * **Temperature:** 13 degrees (Celsius, presumably, as Fahrenheit would be unusual for this condition).
    * **Condition:** Sunny
    * **Humidity:** 38%

* **Image Data:**
    * **Image Path:** `example.jpg`
    * **Dimensions:** 1920x1080 pixels
    * **Format:** JPEG
    * **Analysis:**  A placeholder indicating that some image analysis was performed, but the specific results aren't given.  This could be anything from object detection to color analysis.

**To analyze trends, we need:**

* **Multiple data points:**  A sequence of weather readings for New York City over a period of time (e.g., hourly, daily, weekly, monthly). This would allow us to observe changes in temperature, humidity, and weather conditions.
* **More detailed image analysis (if applicable):** If the image contains weather-related information (e.g., a satellite image, a weather map), then a more detailed analysis of that image would be needed to extract relevant data for trend analysis. For example, changes in cloud cover over time could be a valuable trend.

In summary, the current data provides a single observation, insufficient for trend analysis.  More data is required to perform any meaningful trend analysis.

Prompt: What is the user's data?
Response: The user's data, as shown in the tool results, consists of:

* **user_id:** `sample_user`
* **name:** `Sample User`
* **last_active:** `2024-11-17T19:44:37.795523` (This is a timestamp indicating the user's last activity.)