stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License
14.92k stars 1.15k forks source link

Gemini throws 400 error while compiling signature #1087

Open felixgao opened 1 month ago

felixgao commented 1 month ago

My compile code throws `google.api_core.exceptions.InvalidArgument: 400 Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again.`

# Signature describing the extraction task: two inputs (document text and the
# extraction task) and one output (semicolon-separated key-value pairs).
# NOTE(review): DSPy presumably renders the class docstring and each field's
# `desc` into the prompt, so wording changes there change model behavior —
# confirm before editing any of that text.
class GenerateExtraction(dspy.Signature):
    """Extracting requested information from a document."""

    context = dspy.InputField(desc="The document text.")
    task = dspy.InputField(desc="The task to extract necessary information from document. Use the mapping values for the keys.")
    # Plain triple-quoted string: there are no placeholders, so no f-prefix.
    # The text (line breaks, indentation and trailing spaces) is unchanged.
    answer = dspy.OutputField(
        desc="""A list of expected key-value pairs, if a key doesn't have a value return N/A. 
        IMPORTANT!!! The list must be semi-colon separated.  
        Do not include any other information.""",
    )
class SimpleDocumentTextQA(dspy.Module):
    """Single-step document QA module that wraps one ``dspy.Predict`` call.

    Args:
        signature: The signature *class* to predict with (e.g.
            ``GenerateExtraction``), not an instance. Defaults to
            ``GenerateExtraction`` when omitted.
    """

    def __init__(self, signature: type[dspy.Signature] | None = None):
        super().__init__()
        if signature is None:
            signature = GenerateExtraction
        self.predictor = dspy.Predict(signature)

    def forward(self, context, question):
        """Run the predictor on one document.

        Args:
            context: The document text, fed to the signature's ``context`` field.
            question: The extraction instruction, fed to the ``task`` field.

        Returns:
            A ``dspy.Prediction`` built from the predictor's output.
        """
        pred = self.predictor(context=context, task=question)
        # NOTE(review): self.predictor presumably already returns a Prediction;
        # re-wrapping makes a copy — confirm this wrapper is still needed.
        return dspy.Prediction(pred)
module = SimpleDocumentTextQA()
# NOTE(review): COPRO appears to request a breadth-derived number of candidate
# completions (`n`) per call; backends limited to one candidate per request
# (Gemini's candidateCount must be 1) may reject n > 1 — confirm.
teleprompter = COPRO(metric=metric_fn, verbose=True, depth=2, breadth=2)
config = dict(num_threads=thread_count, display_progress=True)
optmized = teleprompter.compile(
    module,
    trainset=train,
    eval_kwargs=config,
)

The stack trace

Average Metric: 0.7307692307692308 / 9  (8.1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:26<00:00,  2.96s/it]
Traceback (most recent call last):
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/grpc_helpers.py", line 76, in error_remapped_callable
    return callable_(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/grpc/_channel.py", line 1181, in __call__
    return _end_unary_response_blocking(state, call, False, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/grpc/_channel.py", line 1006, in _end_unary_response_blocking
    raise _InactiveRpcError(state)  # pytype: disable=not-instantiable
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
    status = StatusCode.INVALID_ARGUMENT
    details = "Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again."
    debug_error_string = "UNKNOWN:Error received from peer ipv4:142.251.215.234:443 {grpc_message:"Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again.", grpc_status:3, created_time:"2024-05-31T09:51:39.510654-07:00"}"
>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 308, in <module>
    fire.Fire(main)
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 286, in main
    optmized = teleprompter.compile(
               ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/teleprompt/copro_optimizer.py", line 307, in compile
    instr = dspy.Predict(
            ^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/predict/predict.py", line 61, in __call__
    return self.forward(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/predict/predict.py", line 103, in forward
    x, C = dsp.generate(template, **config)(x, stage=self.stage)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/primitives/predict.py", line 77, in do_generate
    completions: list[dict[str, Any]] = generator(prompt, **kwargs)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 177, in __call__
    return self.request(prompt, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/backoff/_sync.py", line 105, in retry
    ret = target(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 168, in request
    return self.basic_request(prompt, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 126, in basic_request
    response = self.client.generate_content(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py", line 407, in generate_content
    return self._generate_content(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py", line 496, in _generate_content
    gapic_response = self._prediction_client.generate_content(request=request)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py", line 2103, in generate_content
    response = rpc(
               ^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/gapic_v1/method.py", line 131, in __call__
    return wrapped_func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/grpc_helpers.py", line 78, in error_remapped_callable
    raise exceptions.from_grpc_error(exc) from exc
google.api_core.exceptions.InvalidArgument: 400 Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again.

Any ideas on what is causing this?

tom-doerr commented 1 month ago

Is thread_count set to 2? Maybe this is an issue with multithreading/batching.

mikeedjones commented 1 month ago

`breadth` is passed to the model as `n`, so you need to set `breadth` to 1 if you're using Gemini.

felixgao commented 1 month ago

I set thread_count to 1 and it still failed. Here is the updated code:

def signature_optimization(
    module,
    train,
    metric_fn,
    thread_count: int = 1,
    *,
    breadth: int = 2,
    depth: int = 2,
    prompt_model=None,
    save_path: str = "optimized_signature.json",
) -> SimpleDocumentTextQA:
    """Optimize ``module``'s instructions with COPRO and save the result.

    Args:
        module: The program to optimize.
        train: Training examples, passed to COPRO as ``trainset``.
        metric_fn: Metric COPRO uses to score candidate instructions.
        thread_count: Number of evaluation threads (``num_threads``).
        breadth: Forwarded to COPRO. COPRO raises ``ValueError`` unless this
            is greater than 1, so the old hard-coded ``breadth=1`` could
            never succeed.
        depth: Forwarded to COPRO.
        prompt_model: Optional LM forwarded to COPRO as ``prompt_model``.
            NOTE(review): presumably used for the instruction-proposal calls;
            passing an LM that accepts more than one candidate per request may
            avoid Gemini's ``candidateCount must be 1`` error — confirm.
        save_path: File the optimized program is saved to.

    Returns:
        The optimized program returned by ``COPRO.compile`` (presumably a copy
        of ``module`` with updated instructions).
    """
    from dspy.teleprompt import COPRO

    teleprompter = COPRO(
        prompt_model=prompt_model,
        metric=metric_fn,
        verbose=True,
        depth=depth,
        breadth=breadth,
    )
    config = dict(num_threads=thread_count, display_progress=True)
    optimized = teleprompter.compile(
        module,
        trainset=train,
        eval_kwargs=config,
    )
    optimized.save(save_path)
    # compile() already returns the optimized program. Wrapping it in
    # SimpleDocumentTextQA(...) would pass a module where a signature is expected.
    return optimized

When I set breadth to 1, the code failed with the following error:

Evaluation Result (Before Optimization): 0.0
Signature Optimization...
Traceback (most recent call last):
  File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 341, in <module>
    fire.Fire(main)
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 331, in main
    optmized = signature_optimization(module, train, metric_fn, thread_count)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 263, in signature_optimization
    teleprompter = COPRO(metric=metric_fn, verbose=True, depth=2, breadth=1)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/teleprompt/copro_optimizer.py", line 69, in __init__
    raise ValueError("Breadth must be greater than 1")
ValueError: Breadth must be greater than 1

The error makes sense because the COPRO code explicitly requires breadth to be greater than 1. Can someone help me understand what this breadth variable is? There isn't enough documentation on what the variable controls.

class COPRO(Teleprompter):
    """COPRO instruction optimizer (excerpt: only ``__init__`` is shown here)."""

    def __init__(
        self,
        prompt_model=None,
        metric=None,
        breadth=10,
        depth=3,
        init_temperature=1.4,
        track_stats=False,
        **_kwargs,
    ):
        """Validate and store the optimizer settings.

        Args:
            prompt_model: Stored as ``self.prompt_model``.
            metric: Stored as ``self.metric``.
            breadth: Stored as ``self.breadth``; must be greater than 1.
                NOTE(review): according to the discussion on this issue,
                ``compile()`` passes values derived from it
                (``self.breadth - 1`` and ``self.breadth``) as ``n``, the
                number of candidate completions requested from the LM —
                confirm against copro_optimizer.py.
            depth: Stored as ``self.depth``; presumably the number of
                optimization rounds — confirm.
            init_temperature: Stored as ``self.init_temperature``;
                presumably the sampling temperature for candidate proposals.
            track_stats: Stored as ``self.track_stats``.
            **_kwargs: Accepted but not used in this excerpt (so, e.g.,
                ``verbose=True`` has no visible effect here).

        Raises:
            ValueError: If ``breadth <= 1``.
        """
        if breadth <= 1:
            raise ValueError("Breadth must be greater than 1")
        self.metric = metric
        self.breadth = breadth
        self.depth = depth
        self.init_temperature = init_temperature
        self.prompt_model = prompt_model
        self.track_stats = track_stats
tom-doerr commented 1 month ago

I just checked the code, and n is set to

 n=self.breadth - 1,

I would set breadth to 2 so that n becomes 1. I'm not at all confident that this works, though.

mikeedjones commented 1 month ago

Later in COPRO it uses n=breadth, IIRC. I think this is a problem with Gemini + the Model Garden API.

okhat commented 1 month ago

Hmm are we seeing a collision between parameter names in DSPy and parameter names in Gemini? ...

mikeedjones commented 1 month ago

The Model Garden API serves multiple models with a common generation-config model. For Gemini, candidateCount must be 1, but it's being set to 2 by n.