run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: Pydantic issue with Gemini LLM #15618

Closed: Shiva4113 closed this issue 1 month ago

Shiva4113 commented 1 month ago

Bug Description

I was running the example code from the official documentation and got a ValidationError.

from llama_index.llms.gemini import Gemini

resp = Gemini().complete("Write a poem about a magic backpack")
print(resp)

ValidationError: 1 validation error for Gemini model Input should be a valid string [type=string_type, input_value=genai.GenerativeModel( ...stem_instruction=None, ), input_type=GenerativeModel] For further information visit https://errors.pydantic.dev/2.8/v/string_type

Version

0.11.1

Steps to Reproduce

You need a Gemini API key. This was my code:

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
import os 
from dotenv import load_dotenv
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.core import Settings
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.llms.gemini import Gemini
import google.generativeai as genai

load_dotenv()

GOOGLE_API_KEY = os.getenv('GEMINI_API_KEY')
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

resp = Gemini().complete("Write a poem about a magic backpack")
print(resp)
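
A quick way to rule out the API key or the SDK itself is to call google-generativeai directly, bypassing the llama_index wrapper. This is a minimal sketch; the model name is just an example:

import os
import google.generativeai as genai

# Configure the SDK with the same key that llama_index would pick up.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# Call the Gemini API without going through the llama_index wrapper.
model = genai.GenerativeModel("gemini-1.5-flash")
print(model.generate_content("Write a poem about a magic backpack").text)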

Relevant Logs/Tracebacks

{
    "name": "ValidationError",
    "message": "1 validation error for Gemini
model
  Input should be a valid string [type=string_type, input_value=genai.GenerativeModel(
  ...stem_instruction=None,
), input_type=GenerativeModel]
    For further information visit https://errors.pydantic.dev/2.8/v/string_type",
    "stack": "---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
Cell In[27], line 3
      1 from llama_index.llms.gemini import Gemini
----> 3 resp = Gemini().complete(\"Write a poem about a magic backpack\")
      4 print(resp)

File c:\\venv\\Lib\\site-packages\\llama_index\\llms\\gemini\\base.py:161, in Gemini.__init__(self, api_key, model, temperature, max_tokens, generation_config, safety_settings, callback_manager, api_base, transport, model_name, default_headers, **generate_kwargs)
    158 else:
    159     max_tokens = min(max_tokens, model_meta.output_token_limit)
--> 161 super().__init__(
    162     model=model,
    163     temperature=temperature,
    164     max_tokens=max_tokens,
    165     generate_kwargs=generate_kwargs,
    166     callback_manager=callback_manager,
    167 )
    169 self._model_meta = model_meta
    170 self._model = model

File c:\\venv\\Lib\\site-packages\\pydantic\\main.py:193, in BaseModel.__init__(self, **data)
    191 # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
    192 __tracebackhide__ = True
--> 193 self.__pydantic_validator__.validate_python(data, self_instance=self)

ValidationError: 1 validation error for Gemini
model
  Input should be a valid string [type=string_type, input_value=genai.GenerativeModel(
  ...stem_instruction=None,
), input_type=GenerativeModel]
    For further information visit https://errors.pydantic.dev/2.8/v/string_type"
}
dosubot[bot] commented 1 month ago

I found a similar issue that was resolved recently: [Bug]: Gemini not working with llama-index. This issue was closed by a pull request that fixed a Gemini argument error [1].

To resolve the Pydantic validation error for the Gemini model in your code, ensure that all fields in the Gemini class are correctly defined and initialized. Specifically, the PrivateAttr fields _model and _model_meta should be initialized properly within the __init__ method.
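
For context, the error itself is ordinary pydantic v2 behaviour: a field annotated as str rejects any non-string value. A minimal, self-contained sketch (plain pydantic, not llama_index code) that reproduces the same class of error:

from pydantic import BaseModel, ValidationError

class Example(BaseModel):
    model: str  # annotated as a string, like the Gemini.model field below

try:
    # An arbitrary object stands in for the genai.GenerativeModel instance
    # that reached the string field in the traceback above.
    Example(model=object())
except ValidationError as err:
    print(err)  # "Input should be a valid string [type=string_type, ...]"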

Here is the corrected code snippet:

from pydantic import BaseModel, Field, PrivateAttr
import os
import typing
import warnings
from typing import Any, Dict, List, Optional, Tuple
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE

if typing.TYPE_CHECKING:
    import google.generativeai as genai

GEMINI_MODELS = (
    "models/gemini-1.5-flash",
    "models/gemini-1.5-flash-latest",
    "models/gemini-pro",
    "models/gemini-pro-latest",
    "models/gemini-1.5-pro",
    "models/gemini-1.5-pro-latest",
    "models/gemini-1.0-pro",
    "gemini-1.5-flash",
    "gemini-1.5-flash-latest",
    "gemini-1.0-pro",
)

class Gemini(BaseModel):
    """Gemini."""

    model: str = Field(default=GEMINI_MODELS[0], description="The Gemini model to use.")
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = GEMINI_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: "genai.types.SafetySettingOptions" = None,
        callback_manager: Optional[CallbackManager] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        model_name: Optional[str] = None,
        default_headers: Optional[Dict[str, str]] = None,
        **generate_kwargs: Any,
    ):
        """Creates a new Gemini model interface."""
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "Gemini is not installed. Please install it with "
                "`pip install 'google-generativeai>=0.3.0'`."
            )
        if model_name is not None:
            warnings.warn(
                "model_name is deprecated, please use model instead",
                DeprecationWarning,
            )
            model = model_name

        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            config_params["transport"] = transport
        if default_headers:
            default_metadata: List[Tuple[str, str]] = []
            for key, value in default_headers.items():
                default_metadata.append((key, value))
            config_params["default_metadata"] = default_metadata

        genai.configure(**config_params)

        base_gen_config = generation_config if generation_config else {}
        final_gen_config = {"temperature": temperature, **base_gen_config}

        self._model = genai.GenerativeModel(
            model_name=model,
            generation_config=final_gen_config,
            safety_settings=safety_settings,
        )

        self._model_meta = genai.get_model(model)

        supported_methods = self._model_meta.supported_generation_methods
        if "generateContent" not in supported_methods:
            raise ValueError(
                f"Model {model} does not support content generation, only "
                f"{supported_methods}."
            )

        if not max_tokens:
            max_tokens = self._model_meta.output_token_limit
        else:
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
        )

Ensure that the PrivateAttr fields _model and _model_meta are properly initialized within the __init__ method to avoid validation errors. Additionally, verify that the model name provided is valid and supported by the Gemini API [2][3][4].
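
As a minimal illustration of that advice (a sketch with a plain pydantic v2 BaseModel, not the actual llama_index class), the public fields are validated by super().__init__() and the private attributes are populated inside __init__ as well:

from pydantic import BaseModel, PrivateAttr

class Wrapper(BaseModel):
    model: str = "models/gemini-1.5-flash"

    _client: object = PrivateAttr()

    def __init__(self, **kwargs: object) -> None:
        # Validate the public (string) fields ...
        super().__init__(**kwargs)
        # ... then populate private attributes; a plain object stands in for
        # the genai.GenerativeModel client here.
        self._client = object()

w = Wrapper()
print(w.model)  # models/gemini-1.5-flash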

To continue talking to Dosu, mention @dosu.

nerdai commented 1 month ago

Thanks @Shiva4113, #15620 should've fixed this bug.
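
For anyone landing here later, a quick way to check whether the installed packages are newer than the snapshot reported above (a sketch; exact version numbers will vary) is to print the installed distribution versions:

from importlib.metadata import version

# The relevant PyPI distribution names; the fix would only be present in
# releases published after the 0.11.1 snapshot reported in this issue.
print(version("llama-index-core"))
print(version("llama-index-llms-gemini"))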