stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy.ai
MIT License

Fail to load state of compiled model #135

Closed: Puzer closed this issue 1 year ago

Puzer commented 1 year ago

When I try to load the state of a compiled model and then call it, I get an exception here: https://github.com/stanfordnlp/dspy/blob/51e511115c961329223070cc50f5a5d0d61db58d/dsp/templates/template_v2.py#L196

AttributeError: 'dict' object has no attribute 'augmented'

`demo` is a dictionary, but it seems it should be a `dspy.Example`.
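
A minimal sketch of why the loaded demo can end up as a plain dict, assuming (as the error suggests) that demos are dict subclasses with attribute access and that saving writes them out as JSON objects. The `Example` class below is a hypothetical stand-in, not the real `dsp.Example` or `dspy.Example`, and the demo fields are illustrative:

```python
import json


class Example(dict):
    """Hypothetical stand-in: a dict subclass with attribute access (not the real dsp.Example)."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None


# In memory, a bootstrapped demo supports attribute access (fields are illustrative).
demo = Example(query="2+2=", answer="4", augmented=True)
print(demo.augmented)  # True

# Saving serializes the demo as an ordinary JSON object...
saved = json.dumps({"generate_answer": {"demos": [demo]}})

# ...so loading it back yields a plain dict, and the Example type is lost.
loaded_demo = json.loads(saved)["generate_answer"]["demos"][0]
print(type(loaded_demo).__name__)  # dict

try:
    loaded_demo.augmented
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'augmented'
```

The JSON round trip keeps the keys and values but not the Python type, which is consistent with the `AttributeError` above.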

I've tried to work around it by loading the state manually and converting `state['generate_answer']['demos']` into `dspy.Example` objects. That works, but I get the rationale instead of the answer itself. This seems to be specific to ChainOfThought; I think it should work fine for a regular Predictor.
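
The manual workaround described above might look roughly like this. `rewrap_demos` is a hypothetical helper that rebuilds every plain-dict demo in a loaded state, and `SimpleNamespace` stands in for `dspy.Example` so the sketch runs without DSPy. With DSPy you would presumably pass `lambda d: dspy.Example(**d)` instead and hand the result to the module's `load_state`; treat that wiring as an assumption about the API at the time:

```python
import json
from types import SimpleNamespace
from typing import Any, Callable, Dict


def rewrap_demos(state: Dict[str, Any], make_example: Callable[[dict], Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of a saved module state in which every
    plain-dict demo in each predictor's "demos" list is rebuilt with make_example."""
    fixed = {}
    for name, predictor_state in state.items():
        if isinstance(predictor_state, dict) and "demos" in predictor_state:
            predictor_state = dict(predictor_state)  # shallow copy; the input stays untouched
            predictor_state["demos"] = [
                make_example(d) if type(d) is dict else d
                for d in predictor_state["demos"]
            ]
        fixed[name] = predictor_state
    return fixed


# Stand-in for the contents of trained_model.json (illustrative, not a real save file).
saved = '{"generate_answer": {"demos": [{"query": "2+2=", "answer": "4"}]}}'
state = json.loads(saved)

# SimpleNamespace stands in for dspy.Example so this runs without DSPy installed.
fixed = rewrap_demos(state, make_example=lambda d: SimpleNamespace(**d))
print(fixed["generate_answer"]["demos"][0].answer)  # 4
```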

Example to reproduce:

```python
import dspy
from dspy.teleprompt import BootstrapFewShot

# Assumes a language model is already configured, e.g. dspy.settings.configure(lm=...)

class MathSolverSignature(dspy.Signature):
    """Perform math operation"""
    query = dspy.InputField(desc="Mathematical expression")
    answer = dspy.OutputField(desc="Answer to the mathematical expression")

class MathSolverModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(MathSolverSignature)

    def forward(self, **kwargs):
        prediction = self.generate_answer(**kwargs)
        return prediction

dataset = [{"query": "2+2=", "answer": "4"}, {"query": "5*5=", "answer": "25"}, {"query": "0+1=", "answer": "1"}]
dataset = [dspy.Example(**x).with_inputs("query") for x in dataset]

# Metric: exact match on the answer field (the third argument is the trace)
eval_func = lambda example, pred, trace=None: example.answer == pred.answer

model = MathSolverModule()
teleprompter = BootstrapFewShot(metric=eval_func, max_bootstrapped_demos=1, max_labeled_demos=2)
compiled_model = teleprompter.compile(model, trainset=dataset)

test_example_results = compiled_model(query="2+2=")  # OK
compiled_model.save("trained_model.json")

loaded_model = MathSolverModule()
loaded_model.load("trained_model.json")
answer = loaded_model(query="2+2=")  # (!) FAILS with AttributeError
```
arnavsinghvi11 commented 1 year ago

Hi

Thank you for providing this example to reproduce the issue! A couple of suggestions based on your use case:

Regarding getting the rationale instead of the answer itself: the MathSolverModule forward pass returns the result of the predictor call directly (`self.generate_answer(**kwargs)`), which is a Prediction object containing both the rationale and the answer. To return only the answer, you would use `return dspy.Prediction(answer=prediction.answer)`. The RAG() module defined in the intro notebook is a great example of this!

Regarding the AttributeError you encountered: this is likely caused by an internal conversion issue in how the module is serialized, where the Example objects are not restored when it is deserialized. A quick fix is to override the `update` method of `dsp.Example` (class Example in dspy/dsp/primitives/demonstrate.py) with the following snippet, which fixes the data format inconsistencies after the module is loaded:

```python
# Method override on class Example in dsp/primitives/demonstrate.py
def update(self, *args, **kwargs):
    super().update(*args, **kwargs)

    # Re-wrap any plain dicts (e.g. deserialized demos) as Example objects
    for key, value in self.items():
        if isinstance(value, dict) and not isinstance(value, Example):
            self[key] = Example(value)
        elif isinstance(value, list):
            self[key] = [
                Example(item) if isinstance(item, dict) and not isinstance(item, Example) else item
                for item in value
            ]
```
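
To see what the override does in isolation, here is a sketch using hypothetical stand-ins for `dotdict` and `dsp.Example` (not the real classes); only the `update` body is taken from the suggestion above:

```python
class dotdict(dict):
    """Hypothetical stand-in for dsp's dotdict: a dict with attribute access."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None


class Example(dotdict):
    """Hypothetical stand-in for dsp.Example carrying the suggested update override."""

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        # Re-wrap nested plain dicts, including dicts inside lists, as Example objects
        for key, value in self.items():
            if isinstance(value, dict) and not isinstance(value, Example):
                self[key] = Example(value)
            elif isinstance(value, list):
                self[key] = [
                    Example(item) if isinstance(item, dict) and not isinstance(item, Example) else item
                    for item in value
                ]


# Updating with deserialized data converts the demo dicts.
state = Example()
state.update({"demos": [{"query": "2+2=", "answer": "4", "augmented": True}]})
print(state.demos[0].augmented)  # True

# The constructor does not route through the overridden update.
plain = Example({"demos": [{"query": "5*5=", "answer": "25"}]})
print(type(plain["demos"][0]).__name__)  # dict
```

One caveat, which is a property of Python dict subclasses rather than anything DSPy-specific: `dict.__init__` does not call an overridden `update`, so constructing `Example(data)` directly leaves nested dicts unconverted. The override only helps on code paths that actually call `update`.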

That said, we will look into this further, since it seems to be a deeper bug in how the saved module metadata is loaded. Thank you for raising this! Please let me know if you have any additional questions.

okhat commented 1 year ago

This is now fixed on main, @Puzer!

Thanks for reporting. (The fix is on the main branch. It will be part of the next pip release too.)