Open abdullahkarakus opened 1 month ago
Here's a snippet that I use to work around the problem. Hope this helps, but obviously a proper fix would be better.
import dspy
import dspy.utils
# Workaround: dspy.Google crashes with KeyError: 'max_tokens', because Google's
# generation config names the limit `max_output_tokens` instead.
llm: dspy.utils.LM = dspy.Google(model="models/gemini-1.5-flash")
....
# Configure temperature & tokens
max_tokens = 100
temperature = 0.5
# Detect whether the configured LM is the Google client so we can pick the
# kwarg name it expects.
is_google: bool = dspy.settings.lm.__class__.__name__ == "Google"
# Google expects `max_output_tokens`; every other dspy LM uses `max_tokens`.
max_tokens_key: str = 'max_output_tokens' if is_google else 'max_tokens'
kwargs = { max_tokens_key: max_tokens, "temperature": temperature }
The approach below monkey-patches the settings object and the client at runtime. I used the try/except block so that once upstream is fixed, the fallback branch with the translation will no longer be taken. I also needed to pass `request_options` to the client during generation — subclassing `dspy.Google` lets you forward `request_options` to `generate_content`; retry configuration is given as an example.
import dspy
from dsp.utils import Settings
class GoogleSettings(Settings):
    """Settings shim that remaps LM-specific kwarg names on attribute lookup.

    For the Google LM, `max_tokens` is translated to `max_output_tokens`.
    The fallback only runs when the normal lookup raises KeyError, so this
    patch deactivates itself once upstream dspy resolves the name itself.
    """

    def __getattr__(self, name):
        # Per-LM attribute-name translations; keyed by LM class name.
        remap = {"Google": {"max_tokens": "max_output_tokens"}}
        try:
            # Preferred path — used once dspy fixes the issue upstream.
            return super().__getattr__(name)
        except KeyError:
            lm_class = super().__getattr__("lm").__class__.__name__
            translated = remap[lm_class][name]
            return super().__getattr__(translated)
# Install the patched settings globally so the kwarg-name translation applies
# everywhere dspy.settings is consulted.
dspy.settings = GoogleSettings()
class GoogleRetry(dspy.Google):
    """dspy.Google subclass that forwards ``request_options`` (e.g. retry
    configuration) to ``generate_content`` instead of the generation config."""

    def basic_request(self, prompt: str, **kwargs):
        """Send *prompt* to the Google client and record the call.

        Args:
            prompt: The text prompt to generate from.
            **kwargs: Generation-config overrides. A ``request_options`` entry
                is popped out and passed to ``generate_content`` directly.

        Returns:
            The raw response object from ``generate_content``.
        """
        request_options = kwargs.pop("request_options", None)
        raw_kwargs = kwargs
        # Per-call kwargs override the defaults stored on the LM instance.
        kwargs = {
            **self.kwargs,
            **kwargs,
        }
        # Google disallows "n" arguments
        n = kwargs.pop("n", None)
        # Use .get(): the merged kwargs may carry no "temperature" key, and
        # indexing would raise KeyError instead of simply skipping the bump.
        if n is not None and n > 1 and kwargs.get("temperature") == 0.0:
            kwargs["temperature"] = 0.7
        response = self.llm.generate_content(
            prompt,
            generation_config=kwargs,
            # Only pass request_options when the caller supplied one, so the
            # client's own default remains in effect otherwise.
            **({} if request_options is None else {"request_options": request_options}),
        )
        history = {
            "prompt": prompt,
            "response": [response],
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)
        return response
def setup_lm(config):
    """Build a GoogleRetry LM from *config* and install it as the global LM.

    Reads the API key from ``config["GOOGLE"]["GOOGLE_API_KEY"]`` and returns
    the configured LM instance.
    """
    api_key = config["GOOGLE"]["GOOGLE_API_KEY"]
    lm = GoogleRetry(
        model="gemini-1.5-flash",
        api_key=api_key,
        safety_settings={},
    )
    dspy.settings.configure(lm=lm, inherit_config=True)
    return lm
dspy.Google does not work properly, as it always results in `KeyError: 'max_tokens'`. The problem stems from line 95 in `dsp/primitives/predict.py`. This raises a KeyError because Google's models use a `max_output_tokens` property instead of `max_tokens`.