stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

max_tokens error for Google's models #1311

Open abdullahkarakus opened 1 month ago

abdullahkarakus commented 1 month ago

dspy.Google does not work properly: it always fails with KeyError: 'max_tokens'. The problem stems from line 95 of dsp/primitives/predict.py:

 max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])

This raises a KeyError because Google's models use a max_output_tokens setting instead of max_tokens, so dsp.settings.lm.kwargs has no "max_tokens" key.
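Note that this line fails even when the caller passes max_tokens explicitly: Python evaluates the default argument of dict.get before get() runs, so lm.kwargs["max_tokens"] is always read. The sketch below reproduces this with plain dicts and shows one way to accept either key name. The dictionaries, their values, the tolerant_lookup helper, and its fallback of 150 are all hypothetical, not dspy code.

```python
# Hypothetical stand-ins for an LM's default kwargs; not real dspy objects.
OPENAI_STYLE_DEFAULTS = {"temperature": 0.0, "max_tokens": 150}
GOOGLE_STYLE_DEFAULTS = {"temperature": 0.0, "max_output_tokens": 2048}


def eager_lookup(call_kwargs: dict, lm_kwargs: dict) -> int:
    # Same shape as the predict.py line: the second argument to dict.get is
    # evaluated before get() is called, so lm_kwargs["max_tokens"] is read
    # even when call_kwargs already contains "max_tokens".
    return call_kwargs.get("max_tokens", lm_kwargs["max_tokens"])


def tolerant_lookup(call_kwargs: dict, lm_kwargs: dict, fallback: int = 150) -> int:
    # Hypothetical helper: prefer the per-call value, then the LM defaults,
    # accepting either the OpenAI-style or the Google-style key name.
    for source in (call_kwargs, lm_kwargs):
        for key in ("max_tokens", "max_output_tokens"):
            if key in source:
                return source[key]
    return fallback
```

With Google-style defaults, eager_lookup({"max_tokens": 75}, GOOGLE_STYLE_DEFAULTS) still raises KeyError, while tolerant_lookup returns 75 for the same call and 2048 when no per-call value is given.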

grant-d commented 1 month ago

Here's a snippet I use to work around the problem. HTH, but obviously a proper fix would be better.

import dspy
import dspy.utils

llm: dspy.utils.LM = dspy.Google(model="models/gemini-1.5-flash")
....

# Configure temperature & tokens
max_tokens = 100
temperature = 0.5

is_google: bool = dspy.settings.lm.__class__.__name__ == "Google"
max_tokens_key: str = 'max_output_tokens' if is_google else 'max_tokens'
kwargs = { max_tokens_key: max_tokens, "temperature": temperature }
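The snippet picks the key by comparing class names, which misses subclasses such as the GoogleRetry wrapper in the next comment (its class name is "GoogleRetry", not "Google"). The sketch below contrasts the name check with isinstance. Google and GoogleRetry here are hypothetical empty placeholders rather than the dspy classes; with the real library you would test against dspy.Google instead.

```python
# Hypothetical placeholders for dspy.Google and a user-defined subclass.
class Google:
    pass


class GoogleRetry(Google):
    pass


def max_tokens_key_by_name(lm: object) -> str:
    # The snippet's check: an exact class-name match, so subclasses fail it.
    return "max_output_tokens" if type(lm).__name__ == "Google" else "max_tokens"


def max_tokens_key_by_type(lm: object) -> str:
    # isinstance also matches subclasses of Google.
    return "max_output_tokens" if isinstance(lm, Google) else "max_tokens"


def generation_kwargs(lm: object, max_tokens: int = 100, temperature: float = 0.5) -> dict:
    # Build the same dict as the snippet, using the subclass-aware check.
    return {max_tokens_key_by_type(lm): max_tokens, "temperature": temperature}
```

For a GoogleRetry instance, max_tokens_key_by_name returns "max_tokens" (the wrong key), while max_tokens_key_by_type returns "max_output_tokens".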

d3banjan commented 1 month ago

Monkey-patching the settings and the client at runtime. I used a try/except block so that once upstream is fixed, the fallback branch will no longer be used. I also needed to pass request_options to the client during generation; patching the dspy.Google client lets you pass request_options to generate_content, with retry as the example use case.

import dspy
from dsp.utils import Settings

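class GoogleSettings(Settings):
    # OpenAI-style setting names mapped to the names Google's client uses.
    translations = {"max_tokens": "max_output_tokens"}

    def __getattr__(self, name):
        try:
            # Normal lookup; once dspy fixes the issue, this branch succeeds.
            return super().__getattr__(name)
        except (KeyError, AttributeError):
            # isinstance also covers subclasses such as GoogleRetry below,
            # whose class name is not "Google".
            lm = super().__getattr__("lm")
            if isinstance(lm, dspy.Google) and name in self.translations:
                return super().__getattr__(self.translations[name])
            raise AttributeError(name)

dspy.settings = GoogleSettings()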

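class GoogleRetry(dspy.Google):
    def basic_request(self, prompt: str, **kwargs):
        # Pull request_options out so it is passed to generate_content
        # rather than merged into generation_config.
        request_options = kwargs.pop("request_options", None)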
        raw_kwargs = kwargs
        kwargs = {
            **self.kwargs,
            **kwargs,
        }

        # Google disallows "n" arguments
        n = kwargs.pop("n", None)
        if n is not None and n > 1 and kwargs["temperature"] == 0.0:
            kwargs["temperature"] = 0.7

        response = self.llm.generate_content(
            prompt,
            generation_config=kwargs,
            **({} if request_options is None else {"request_options": request_options}),
        )

        history = {
            "prompt": prompt,
            "response": [response],
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

def setup_lm(config):
    _lm = GoogleRetry(
        model="gemini-1.5-flash",
        api_key=config["GOOGLE"]["GOOGLE_API_KEY"],
        safety_settings={},
    )
    dspy.settings.configure(lm=_lm, inherit_config=True)
    return _lm
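The override above forwards request_options but still copies every key it receives into generation_config. If dsp's retry path passes an OpenAI-style max_tokens back into the LM call (an assumption worth checking in predict.py), that key would need renaming too. The sketch below runs without dspy or network access: FakeGenerativeModel, GoogleLikeLM, and KEY_ALIASES are hypothetical names, and the 2048 default and the {"timeout": 30} options are illustrative values. It aliases max_output_tokens as self.kwargs["max_tokens"] so a lookup like the one on line 95 succeeds, renames per-call keys before building generation_config, and forwards request_options only when given.

```python
from typing import Any, Optional


class FakeGenerativeModel:
    """Hypothetical stand-in for a Gemini client that only records calls."""

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    def generate_content(self, prompt: str, generation_config: dict, **extra: Any) -> str:
        self.calls.append({"prompt": prompt, "generation_config": generation_config, **extra})
        return f"response to {prompt!r}"


# Hypothetical mapping from OpenAI-style names to Google's generation_config names.
KEY_ALIASES = {"max_tokens": "max_output_tokens"}


class GoogleLikeLM:
    """Hypothetical sketch of a Google wrapper; not dspy.Google itself."""

    def __init__(self, llm: FakeGenerativeModel, **kwargs: Any) -> None:
        self.llm = llm
        init = {KEY_ALIASES.get(k, k): v for k, v in kwargs.items()}
        self.kwargs: dict[str, Any] = {"n": 1, "temperature": 0.0, "max_output_tokens": 2048, **init}
        # Alias the OpenAI-style key so lm.kwargs["max_tokens"] does not raise KeyError.
        self.kwargs["max_tokens"] = self.kwargs["max_output_tokens"]
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs: Any) -> str:
        request_options: Optional[dict] = kwargs.pop("request_options", None)
        # Rename per-call keys, then layer them over the stored defaults
        # (minus the alias), so a per-call limit always wins.
        call = {KEY_ALIASES.get(k, k): v for k, v in kwargs.items()}
        defaults = {k: v for k, v in self.kwargs.items() if k not in KEY_ALIASES}
        config = {**defaults, **call}
        # Google disallows an "n" argument.
        config.pop("n", None)

        extra = {} if request_options is None else {"request_options": request_options}
        response = self.llm.generate_content(prompt, generation_config=config, **extra)
        self.history.append({"prompt": prompt, "response": [response], "kwargs": config, "raw_kwargs": kwargs})
        return response
```

Translating only the per-call kwargs, and leaving the alias out of the stored defaults, keeps the stored 2048 from overwriting a shorter per-call limit such as max_tokens=75.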