googleapis / python-aiplatform

A Python SDK for Vertex AI, a fully managed, end-to-end platform for data science and machine learning.
Apache License 2.0

Palm2 TextGeneration parameter logit_bias not working #4111

Closed. alexandrasouly closed this issue 1 month ago.

alexandrasouly commented 1 month ago

I'm running into the same issue as https://stackoverflow.com/questions/77805968/how-to-use-logit-bias-parameter-of-model-predict-in-vertexai-python-sdk when using the logit_bias parameter of the PaLM 2 text generation API (TextGenerationModel.predict), as documented in https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.language_models.TextGenerationModel#vertexai_language_models_TextGenerationModel_predict

Steps to reproduce

The following code, which passes the documented logit_bias parameter, raises an error:

  from vertexai.language_models import TextGenerationModel

  model = TextGenerationModel.from_pretrained("text-bison@002")
  responses = model.predict("hi?", logit_bias={10: -13.0})  # int token-ID key fails

Stack trace

Traceback (most recent call last):
  File "/home/ubuntu/code/folder/delete.py", line 42, in <module>
    generate()
  File "/home/ubuntu/code/folder/delete.py", line 13, in generate
    responses = model.predict("hi?",logit_bias={10:-13.})
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/vertexai/language_models/_language_models.py", line 1417, in predict
    prediction_response = self._endpoint.predict(
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/google/cloud/aiplatform/models.py", line 2105, in predict
    prediction_response = self._prediction_client.predict(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/google/cloud/aiplatform_v1/services/prediction_service/client.py", line 835, in predict
    request.parameters = parameters
    ^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/message.py", line 935, in __setattr__
    pb_value = marshal.to_proto(pb_type, value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/google/cloud/aiplatform/utils/enhanced_library/_decorators.py", line 33, in to_proto
    return super().to_proto(value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/rules/struct.py", line 82, in to_proto
    struct_value=self._marshal.to_proto(struct_pb2.Struct, value),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/rules/struct.py", line 139, in to_proto
    fields={
           ^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/rules/struct.py", line 140, in <dictcomp>
    k: self._marshal.to_proto(struct_pb2.Value, v) for k, v in value.items()
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/google/cloud/aiplatform/utils/enhanced_library/_decorators.py", line 33, in to_proto
    return super().to_proto(value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/rules/struct.py", line 82, in to_proto
    struct_value=self._marshal.to_proto(struct_pb2.Struct, value),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.local/lib/python3.11/site-packages/proto/marshal/rules/struct.py", line 138, in to_proto
    answer = struct_pb2.Struct(

sasha-gitg commented 1 month ago

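The type annotation for the key is incorrect: it should be a string, not an int. The request parameters are converted into a protobuf Struct (the struct_pb2.Struct call at the bottom of the trace), and a Struct only accepts string keys, so the token ID has to be passed as a string: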

  responses = model.predict("hi?", logit_bias={"10": -13.0})
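If your code already holds integer token IDs, one way to apply that workaround is to convert the keys just before calling predict. The helper normalize_logit_bias below is hypothetical (it is not part of the Vertex AI SDK); it is a minimal sketch that assumes token IDs are non-negative integers or digit strings, and it does not check the allowed range of bias values.

```python
from typing import Dict, Mapping, Union


def normalize_logit_bias(
    logit_bias: Mapping[Union[int, str], float],
) -> Dict[str, float]:
    """Return a copy of logit_bias whose token-ID keys are decimal strings.

    Hypothetical helper: the parameters end up in a protobuf Struct, which
    only accepts string keys, so 10 (or "010") becomes "10".
    """
    normalized: Dict[str, float] = {}
    for token_id, bias in logit_bias.items():
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(token_id, int) and not isinstance(token_id, bool) and token_id >= 0:
            key = str(token_id)
        elif isinstance(token_id, str) and token_id.isascii() and token_id.isdigit():
            key = str(int(token_id))  # canonical form, e.g. "010" -> "10"
        else:
            raise ValueError(f"expected a non-negative token ID, got {token_id!r}")
        normalized[key] = float(bias)
    return normalized


# Hypothetical usage with the model from the issue:
#   responses = model.predict("hi?", logit_bias=normalize_logit_bias({10: -13.0}))
print(normalize_logit_bias({10: -13.0}))
```

Converting at the call site lets the rest of the code keep working with integer token IDs while the request always carries string keys.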

sasha-gitg commented 1 month ago

The type annotations were fixed in 1.60: https://github.com/googleapis/python-aiplatform/releases/tag/v1.60.0
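To check whether an installed SDK already carries the corrected annotations, you can compare its version against 1.60.0. As far as the comment above indicates, only the annotations changed, so passing string keys is still what works at runtime on any version. The sketch below is a stdlib-only illustration with hypothetical helper names: it reads the installed version of the google-cloud-aiplatform distribution via importlib.metadata and compares only the numeric release, ignoring pre-release suffixes. For full PEP 440 semantics, packaging.version.Version is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def release_tuple(release: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. "1.60.0rc1" -> (1, 60, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", release)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def is_at_least(installed: str, minimum: str) -> bool:
    """Compare numeric releases, padding the shorter one with zeros.

    Pre-release and local suffixes are ignored, so "1.60.0rc1" counts as 1.60.0.
    """
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def has_annotation_fix(minimum: str = "1.60.0") -> bool:
    """Return True if the installed google-cloud-aiplatform is at least `minimum`."""
    try:
        installed = version("google-cloud-aiplatform")
    except PackageNotFoundError:
        return False  # the SDK is not installed in this environment
    return is_at_least(installed, minimum)


print(has_annotation_fix())
```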