google-gemini / generative-ai-python

The official Python library for the Google Gemini API
https://pypi.org/project/google-generativeai/
Apache License 2.0

Uri parameter cannot be overridden when sending via rest transport #319

Open · karimmohraz opened this issue 5 months ago

karimmohraz commented 5 months ago

Description of the bug:

We want to override the standard Google URI part of the model endpoint. Although some parameters, e.g. the API endpoint (host), can be overridden, the URI path is hard-coded.

Context: we have our own AI hub where we count tokens for a set of models before forwarding the request to the model hosted by Google, Azure OpenAI, etc. In the case of OpenAI, for example, we can override parameters such as endpoint and api_key. But this is not the case for the Google Generative AI SDK.

Steps to reproduce

Code Example

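    import google.generativeai as genai
    from google.ai.generativelanguage import GenerativeServiceClient

    class MyGenerativeModel(genai.GenerativeModel):
        def __init__(self, proxy_client, **kwargs):
            super().__init__(**kwargs)
            self.proxy_client = proxy_client
            # The host can be overridden here, but the URI path ("/v1beta/...") cannot.
            self._client = GenerativeServiceClient(
                transport="rest", client_options={"api_endpoint": "myhost"})

    # In the test:
    kwargs = dict(model_name="gemini-1.0-pro")
    model = MyGenerativeModel(proxy_client=self.proxy_client, **kwargs)
    content = self._get_test_messages()
    model_response = model.generate_content(content)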

Actual vs expected behavior:

However, in venv/lib/python3.9/site-packages/google/ai/generativelanguage_v1beta/services/generative_service/transports/rest.py:810 the URI seems to be hard-coded, whereas the host can be overridden.

We want to override it with this URI: "models/gemini-1.0-pro:generateContent" (i.e. remove the "/v1beta/" prefix).

An interceptor did not help either, because the request is immutable:

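from google.ai.generativelanguage_v1beta.services.generative_service.transports.rest import GenerativeServiceRestInterceptor

class MyCustomGenerativeServiceInterceptor(GenerativeServiceRestInterceptor):
    def pre_generate_content(self, request, metadata):
        # logging.info("Received request: %s", request)
        # metadata is typed as Sequence[Tuple[str, str]], so copy it instead of appending in place.
        metadata = list(metadata) + [("uri", "models/gemini-1.0-pro:generateContent")]
        return request, metadata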

Please provide an option to override the URI, or let us know if there is another way. Thanks!

Any other information you'd like to share?

Environment details:
OS type and version: macOS (latest)
Python version: 3.9
pip version: 24.0
google-cloud-aiplatform version: 1.47.0

MarkDaoust commented 4 months ago

Okay, it's Python, so there's always a workaround. Let's see what we can do.

First off, everything in google/ai/generativelanguage is generated from the proto files that define the API, so we can't easily send updates to those files to make them work differently.

rest.py:810

The interceptor class doesn't help because it doesn't even get access to the URI. The URI is in http_options, which is not passed in. Metadata becomes extra headers on the request, and the request at that point is just the request body.

We either need to edit the protos and generate our own version of the library, or find some other way to patch it in.
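To make the point about http_options concrete, here is a simplified, hypothetical stand-in for how a REST transport turns its URI templates and the request fields into a URL. The generated code does this via google.api_core.path_template.transcode; `build_url`, `DEFAULT_OPTIONS` and the "host + URI" joining rule are assumptions for illustration only. Note that metadata never reaches the path.

```python
import re
from typing import Dict, List, Sequence, Tuple

# Hypothetical, simplified stand-in for the URI step of the REST transport.
# It handles a single "{field=pattern}" variable per template.
_VARIABLE = re.compile(r"\{(?P<name>[\w.]+)=(?P<pattern>[^}]+)\}")

# Assumed to mirror the bindings hard-coded in the generated rest.py.
DEFAULT_OPTIONS: List[Dict[str, str]] = [
    {"method": "post", "uri": "/v1beta/{model=models/*}:generateContent", "body": "*"},
    {"method": "post", "uri": "/v1beta/{model=tunedModels/*}:generateContent", "body": "*"},
]


def _pattern_to_regex(pattern: str) -> str:
    # "models/*" -> "models/[^/]+": each "*" matches exactly one path segment.
    return "/".join("[^/]+" if part == "*" else re.escape(part) for part in pattern.split("/"))


def build_url(host: str, http_options: List[Dict[str, str]], request: Dict[str, str],
              metadata: Sequence[Tuple[str, str]] = ()) -> str:
    """Return host + the URI of the first binding whose pattern matches the request."""
    # `metadata` is deliberately unused: it only becomes extra HTTP headers,
    # so it cannot change the path.
    for option in http_options:
        template = option["uri"]
        match = _VARIABLE.search(template)
        if match is None:
            return host + template
        value = request.get(match.group("name"), "")
        if re.fullmatch(_pattern_to_regex(match.group("pattern")), value):
            return host + template[: match.start()] + value + template[match.end():]
    raise ValueError("no http_options binding matches the request")


request = {"model": "models/gemini-1.0-pro"}
metadata = [("uri", "models/gemini-1.0-pro:generateContent")]
print(build_url("https://myhost", DEFAULT_OPTIONS, request, metadata))
# https://myhost/v1beta/models/gemini-1.0-pro:generateContent
```

Whatever goes into metadata, the "/v1beta/" prefix survives; only a different set of templates changes the path.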

Regenerating from protos

One nice effect of this approach is that you can generate the clients for other languages too.

The lines in the protos that set the REST URIs are these:

https://github.com/googleapis/googleapis/blob/79c1b132c6c8220ad2a071bd2338236f15c807b6/google/ai/generativelanguage/v1beta/generative_service.proto#L45-L56

From that directory the command to generate the library is:

https://github.com/googleapis/googleapis/blob/79c1b132c6c8220ad2a071bd2338236f15c807b6/google/ai/generativelanguage/v1beta/BUILD.bazel#L226

bazelisk build //google/ai/generativelanguage/v1beta:ai-generativelanguage-v1beta-py

That generates a .tar.gz file that you can pip install. You may want to adjust its version number.

Patching

I'm not sure what the best approach is here. I normally work above the google.ai.generativelanguage package, not inside of it.

I'd start by forking GenerativeServiceRestTransport into a separate module and replacing all the URIs. Then the question is how to patch it in.

Where does GenerativeServiceRestTransport get used? It looks like generative_service/client.py and generative_service/transports/__init__.py. Both put it in a _transport_registry['rest'].

It's not clear that the one in __init__.py does anything.

So once you have the replacement version of GenerativeServiceRestTransport you might be able to just say:

import my_rest_transport  # your forked module containing MyGenerativeServiceRestTransport
import google.ai.generativelanguage_v1beta.services.generative_service.client as gs_client

# Register the forked transport under a new label on the client's metaclass.
gs_client.GenerativeServiceClientMeta._transport_registry['my_rest'] = my_rest_transport.MyGenerativeServiceRestTransport

client = gs_client.GenerativeServiceClient(transport='my_rest')
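Why registering a new label should be enough can be seen in a toy model of the pattern described here: the client's metaclass owns a class-level `_transport_registry`, and a string `transport=` argument is resolved through it. All `Toy*` names are hypothetical; they only mimic the shape of the generated client, not its full behavior.

```python
from collections import OrderedDict
from typing import Dict, Type


class ToyTransport:
    def __init__(self, host: str = "generativelanguage.googleapis.com"):
        self.host = host


class ToyRestTransport(ToyTransport):
    """Stands in for the generated REST transport."""


class MyToyRestTransport(ToyRestTransport):
    """Stands in for a fork whose URIs have been replaced."""


class ToyClientMeta(type):
    # Class-level registry shared by every client created with this metaclass.
    _transport_registry: Dict[str, Type[ToyTransport]] = OrderedDict()
    _transport_registry["rest"] = ToyRestTransport

    def get_transport_class(cls, label: str) -> Type[ToyTransport]:
        return cls._transport_registry[label]


class ToyClient(metaclass=ToyClientMeta):
    def __init__(self, transport: str = "rest"):
        # A string label is looked up in the metaclass registry.
        self._transport = type(self).get_transport_class(transport)()


# Registering a new label on the metaclass makes it available to ToyClient.
ToyClientMeta._transport_registry["my_rest"] = MyToyRestTransport
client = ToyClient(transport="my_rest")
print(type(client._transport).__name__)  # MyToyRestTransport
```

Because the registry lives on the metaclass rather than on an instance, the patch has to run before the client is constructed, but it does not need to touch the client class itself.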
MarkDaoust commented 4 months ago

But I'm not sure I caught the question about API keys.

karimmohraz commented 4 months ago

Hello Mark, thanks for your reply. Currently I have overridden the class GenerativeServiceRestTransport. Inside this class there is the __call__ method, where I have adapted the hard-coded http_options:

    http_options: List[Dict[str, str]] = [
        {
            "method": "post",
            "uri": "{model=models/*}:generateContent",
            "body": "*",
        },
        {
            "method": "post",
            "uri": "{model=tunedModels/*}:generateContent",
            "body": "*",
        },
    ]

So it would be great if the http options could be parametrized; then I would not need to override the class and internal methods like __call__.
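If the bindings were exposed as a parameter, the rewrite itself would be small. The hypothetical helper `replace_uri_prefix` below sketches that step in plain Python; the default bindings are assumed to mirror the generated ones (with the "/v1beta/" prefix), and per this thread the library offers no such hook today.

```python
import copy
from typing import Dict, List

# Hypothetical defaults, assumed to mirror the bindings in the generated rest.py.
DEFAULT_GENERATE_CONTENT_OPTIONS: List[Dict[str, str]] = [
    {"method": "post", "uri": "/v1beta/{model=models/*}:generateContent", "body": "*"},
    {"method": "post", "uri": "/v1beta/{model=tunedModels/*}:generateContent", "body": "*"},
]


def replace_uri_prefix(http_options: List[Dict[str, str]],
                       old_prefix: str = "/v1beta/",
                       new_prefix: str = "/") -> List[Dict[str, str]]:
    """Return a copy of http_options with old_prefix replaced by new_prefix in each URI."""
    result = copy.deepcopy(http_options)  # never mutate the shared defaults
    for option in result:
        uri = option["uri"]
        if uri.startswith(old_prefix):
            option["uri"] = new_prefix + uri[len(old_prefix):]
    return result


proxy_options = replace_uri_prefix(DEFAULT_GENERATE_CONTENT_OPTIONS)
print(proxy_options[0]["uri"])  # /{model=models/*}:generateContent
```

Passing `new_prefix=""` reproduces the list above exactly; whether a leading "/" is needed depends on how the transport joins the host and the URI, so that choice is left to the caller.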

Apart from the URI suffix, i.e. the inference method in the URI ("generateContent"), I successfully passed the endpoint and the "rest" transport parameter. In my overridden generateContent method I could also pass "request_options" containing a bearer token as a request header. So everything worked via parametrization except the hard-coded URI part.