stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

Error in trying to save compiled dspy program using optimized_program.save(YOUR_SAVE_PATH) #1386

Closed: vikrantrathore closed this issue 2 months ago

vikrantrathore commented 2 months ago

Saving an optimized program fails. I'm using the CompiledBaleen example from https://github.com/stanfordnlp/dspy/blob/d3518d7ee9b717e816893d6dc1a875d687f88780/docs/docs/tutorials/simplified-baleen.md?plain=1#L186

I'm trying to save it with the following code:

import dspy
from dspy.teleprompt import BootstrapFewShot

# SimplifiedBaleen and trainset are defined as in the simplified-baleen tutorial linked above.

def validate_context_and_answer_and_hops(example, pred, trace=None):
    # Reject predictions whose answer or retrieved passages don't match the gold example.
    if not dspy.evaluate.answer_exact_match(example, pred): return False
    if not dspy.evaluate.answer_passage_match(example, pred): return False

    # The original question plus every search query generated along the trace.
    hops = [example.question] + [outputs.query for *_, outputs in trace if 'query' in outputs]

    # Reject overly long queries and queries that largely repeat an earlier hop.
    if max([len(h) for h in hops]) > 100: return False
    if any(dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) for idx in range(2, len(hops))): return False

    return True

teleprompter = BootstrapFewShot(metric=validate_context_and_answer_and_hops)
compiled_baleen = teleprompter.compile(SimplifiedBaleen(), teacher=SimplifiedBaleen(passages_per_hop=2), trainset=trainset)
compiled_baleen.save(path="multihop_qa.json")

This raises a TypeError in Module.dump_state (https://github.com/stanfordnlp/dspy/blob/d3518d7ee9b717e816893d6dc1a875d687f88780/dspy/primitives/module.py#L117).

Here is the full output, ending with the error:

Bootstrapped 1 full traces after 4 examples in round 0.
[('generate_query[0]', Predict(StringSignature(context, question -> rationale, query
    instructions='Write a simple search query that will help answer a complex question.'
    context = Field(annotation=str required=True json_schema_extra={'desc': 'may contain relevant facts', 'dspy_field_type': 'input', 'prefix': 'Context:'})
    question = Field(annotation=str required=True json_schema_extra={'dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    rationale = Field(annotation=str required=True json_schema_extra={'prefix': "Reasoning: Let's think step by step in order to", 'desc': '${produce the query}. We ...', 'dspy_field_type': 'output'})
    query = Field(annotation=str required=True json_schema_extra={'dspy_field_type': 'output', 'prefix': 'Query:', 'desc': '${query}'})
))), ('generate_query[1]', Predict(StringSignature(context, question -> rationale, query
    instructions='Write a simple search query that will help answer a complex question.'
    context = Field(annotation=str required=True json_schema_extra={'desc': 'may contain relevant facts', 'dspy_field_type': 'input', 'prefix': 'Context:'})
    question = Field(annotation=str required=True json_schema_extra={'dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    rationale = Field(annotation=str required=True json_schema_extra={'prefix': "Reasoning: Let's think step by step in order to", 'desc': '${produce the query}. We ...', 'dspy_field_type': 'output'})
    query = Field(annotation=str required=True json_schema_extra={'dspy_field_type': 'output', 'prefix': 'Query:', 'desc': '${query}'})
))), ('retrieve', <dspy.retrieve.retrieve.Retrieve object at 0x12cf48590>), ('generate_answer', Predict(StringSignature(context, question -> rationale, answer
    instructions='Answer questions with short factoid answers.'
    context = Field(annotation=str required=True json_schema_extra={'desc': 'may contain relevant facts', 'dspy_field_type': 'input', 'prefix': 'Context:'})
    question = Field(annotation=str required=True json_schema_extra={'dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    rationale = Field(annotation=str required=True json_schema_extra={'prefix': "Reasoning: Let's think step by step in order to", 'desc': '${produce the answer}. We ...', 'dspy_field_type': 'output'})
    answer = Field(annotation=str required=True json_schema_extra={'desc': 'often between 1 and 5 words', 'dspy_field_type': 'output', 'prefix': 'Answer:'})
)))]

Traceback (most recent call last):
  File "/Users/guestuser/projects/test_dspy/multihop_qa.py", line 98, in <module>
    compiled_baleen.save(path="multihotpot_qa.json")
  File "/Users/guestuser/projects/test_dspy/.venv/lib/python3.12/site-packages/dspy/primitives/module.py", line 132, in save
    f.write(ujson.dumps(self.dump_state(save_field_meta), indent=2))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guestuser/projects/test_dspy/.venv/lib/python3.12/site-packages/dspy/primitives/module.py", line 117, in dump_state
    return {name: param.dump_state(save_field_meta) for name, param in self.named_parameters()}
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Retrieve.dump_state() takes 1 positional argument but 2 were given

isaacbmiller commented 2 months ago

This is related to the save-field-meta changes in #1370 (@CShorten @arnavsinghvi11).

Module.dump_state() calls param.dump_state(save_field_meta) on every child parameter. Besides Predict, a decent number of other classes inherit from Parameter (Retrieve among them), and their dump_state() doesn't accept the new argument. We either need to add the argument to all of them, or make sure it is only passed to predictors.
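To make the failure mode concrete, here is a minimal, dependency-free sketch of the pattern described above. The classes are hypothetical stand-ins that only mimic the shape of DSPy's Parameter, Predict, Retrieve and Module; they are not DSPy's actual code. The parent forwards the new flag positionally to every child, so a child whose dump_state() still takes only self raises the same kind of TypeError as in the traceback.

```python
class Parameter:
    """Hypothetical stand-in for a DSPy-style Parameter base class."""


class Predict(Parameter):
    # Updated child: accepts the new flag.
    def dump_state(self, save_field_meta=False):
        return {"kind": "predict", "field_meta": save_field_meta}


class Retrieve(Parameter):
    # Child that was not updated: dump_state() takes only self.
    def dump_state(self):
        return {"kind": "retrieve", "k": 1}


class Module:
    def __init__(self, **children):
        self.children = children

    def named_parameters(self):
        return list(self.children.items())

    def dump_state(self, save_field_meta=False):
        # Forwards the flag to every child, regardless of its type.
        return {name: param.dump_state(save_field_meta)
                for name, param in self.named_parameters()}


program = Module(generate_answer=Predict(), retrieve=Retrieve())
try:
    program.dump_state(True)
except TypeError as err:
    # The Retrieve child rejects the extra positional argument.
    print(err)
```

A module made only of updated children serializes fine; adding a single un-updated child is enough to break the whole save.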

Simple repro:

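import dspy

class DumpStateRepro(dspy.Module):
    def __init__(self):
        super().__init__()
        # Retrieve is a Parameter, but not a predictor.
        self.retrieve = dspy.Retrieve(k=1)

answer_generator = DumpStateRepro()

# Raises TypeError: Retrieve.dump_state() takes 1 positional argument but 2 were given
answer_generator.save("DumpStateRepro", save_field_meta=True)
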
isaacbmiller commented 2 months ago

I'd be in favor of reverting #1370 until this is fixed, so that this bug doesn't accidentally make it into the next DSPy release.

We should also 1000% have tests for this.
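One way to write such a test, sketched here under the assumption that every Parameter subclass must accept the flag, is to walk the subclass tree and check that each dump_state() can bind one extra positional argument, without calling it. The classes below are hypothetical stand-ins, not DSPy's, and Retrieve is deliberately left broken; in a real test suite the walk would start from DSPy's own Parameter class.

```python
import inspect


class Parameter:
    """Hypothetical stand-in for a DSPy-style Parameter base class."""

    def dump_state(self, save_verbose=False):
        return {}


class Predict(Parameter):
    def dump_state(self, save_verbose=False):
        return {"verbose": save_verbose}


class Retrieve(Parameter):
    # Deliberately broken: does not accept the flag.
    def dump_state(self):
        return {"k": 1}


def all_subclasses(cls):
    """Recursively collect every subclass of cls."""
    found = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


def incompatible_dump_state(base):
    """Names of subclasses whose dump_state() rejects one positional flag."""
    bad = []
    for sub in all_subclasses(base):
        try:
            # Bind (self, flag) against the signature without calling the method.
            inspect.signature(sub.dump_state).bind(None, True)
        except TypeError:
            bad.append(sub.__name__)
    return bad


print(incompatible_dump_state(Parameter))
```

In a pytest suite this would become a single `assert incompatible_dump_state(Parameter) == []`, which catches any future subclass that forgets the argument.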

CShorten commented 2 months ago

Hey @isaacbmiller! Apologies for missing this!

I added the feature-flag argument to Retrieve and LangChainPredict (the two other child classes of Parameter that I originally missed) in https://github.com/stanfordnlp/dspy/pull/1388.

I'm also renaming the argument from save_field_meta to save_verbose.
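The fix described above amounts to giving every Parameter subclass the same dump_state() signature. Below is a minimal sketch with hypothetical stand-in classes (not the code from PR #1388), using the renamed save_verbose flag: children with nothing extra to save simply accept the flag and ignore it. The "fields" metadata is illustrative only, borrowing the Question: prefix from the log above.

```python
import json


class Parameter:
    """Hypothetical stand-in base class; every subclass shares this signature."""

    def dump_state(self, save_verbose=False):
        raise NotImplementedError


class Predict(Parameter):
    def __init__(self, instructions):
        self.instructions = instructions

    def dump_state(self, save_verbose=False):
        state = {"instructions": self.instructions}
        if save_verbose:
            # Only verbose saves include the (illustrative) field metadata.
            state["fields"] = {"question": {"prefix": "Question:"}}
        return state


class Retrieve(Parameter):
    def __init__(self, k):
        self.k = k

    def dump_state(self, save_verbose=False):
        # Accepts the flag for signature compatibility; nothing extra to save.
        return {"k": self.k}


class Module:
    def __init__(self, **children):
        self.children = children

    def dump_state(self, save_verbose=False):
        # Safe to forward the flag now that every child accepts it.
        return {name: child.dump_state(save_verbose)
                for name, child in self.children.items()}


program = Module(
    generate_answer=Predict("Answer questions with short factoid answers."),
    retrieve=Retrieve(k=1),
)
print(json.dumps(program.dump_state(save_verbose=True), indent=2))
```

The other option mentioned earlier in the thread, forwarding the flag only to predictors, would need an isinstance check inside Module.dump_state; a uniform signature avoids that special case.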

isaacbmiller commented 2 months ago

This should be resolved now.