stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy.ai
MIT License

How to serve a dspy model #454

Open ctyler9 opened 8 months ago

ctyler9 commented 8 months ago

I have a simple model fully configured but cannot get it to work on a Flask/FastAPI server. I have looked at similar issues, but none is conclusive about whether the problem was addressed. The underlying issue seemed to be related to threading. Can someone give me an example of how to do a simple forward pass with a Flask server?

JamesScharf commented 8 months ago

Can you share the specific error you are seeing? I had an issue with RLock/cloudpickle a few weeks ago and was able to avoid the bug by rewriting the module's deep_copy method.
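
For readers unfamiliar with this class of bug: an object that holds a `threading.RLock` cannot be pickled or deep-copied by default, and a custom copy method that skips the lock is one common workaround. The sketch below illustrates that idea only. `CachedClient` and its attributes are hypothetical stand-ins, not DSPy classes, and this is not the actual change described above.

```python
import copy
import threading


class CachedClient:
    """Hypothetical stand-in for an object that holds a lock (not a DSPy class)."""

    def __init__(self, name):
        self.name = name
        self.history = []
        # RLock objects cannot be pickled or deep-copied by default.
        self._lock = threading.RLock()

    def __deepcopy__(self, memo):
        # Build the copy without calling __init__, and register it in memo
        # so that any self-references resolve to the new object.
        clone = self.__class__.__new__(self.__class__)
        memo[id(self)] = clone
        for key, value in self.__dict__.items():
            if key == "_lock":
                # Give the copy a fresh lock instead of copying the old one.
                clone._lock = threading.RLock()
            else:
                setattr(clone, key, copy.deepcopy(value, memo))
        return clone


original = CachedClient("demo")
original.history.append("first call")
duplicate = copy.deepcopy(original)
```

Without the `__deepcopy__` override, `copy.deepcopy(original)` would fail when it reaches the lock attribute. With it, the copy gets independent state and its own lock.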

ctyler9 commented 8 months ago

from flask import Flask, request
from functools import lru_cache
import os
from dotenv import load_dotenv

load_dotenv()

app = Flask(__name__)
counter = {"api" : 0}

# RAG IMPORT 
import dspy

# set up llm and retrieval model
ollama_model = dspy.OllamaLocal(model="mistral:7b", max_tokens=500)
colbertv2 = dspy.ColBERTv2(url='http://localhost:8893/api/search')
dspy.settings.configure(lm=ollama_model, rm=colbertv2)

class GenerateAnswer(dspy.Signature):
    """Answer questions as a TA giving hints""" 

    context = dspy.InputField(desc=" ... ")
    question = dspy.InputField()
    answer = dspy.OutputField(desc=" ... ")

class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)

rag = RAG()

@lru_cache(maxsize=1000000)
def api_search_query(query):
    pred = rag(query)

    return {"query": query, "answer": pred.answer, "context": pred.context}

@app.route("/api/search", methods=["GET"])
def api_search():
    if request.method == "GET":
        counter["api"] += 1
        print("API request count:", counter["api"])
        return api_search_query(request.args.get("query"))
    else:
        return ('', 405)

if __name__ == "__main__":
    # Fall back to port 5000 when PORT is unset; int(None) would raise a TypeError.
    app.run("0.0.0.0", port=int(os.getenv("PORT", "5000")))

ctyler9 commented 8 months ago

It's as simple as this: the Flask server just hangs and won't launch.
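
When a process hangs at startup with no error, one way to find where it is stuck is Python's standard `faulthandler` module, which can dump every thread's stack after a timeout. The following is a minimal sketch under stated assumptions: `run_with_hang_report` and `slow_setup` are hypothetical names, and `slow_setup` merely sleeps in place of the real model and retriever configuration.

```python
import faulthandler
import tempfile
import time


def run_with_hang_report(setup, timeout_s=10.0, report_file=None):
    """Run setup(); if it is still running after timeout_s, dump all thread stacks."""
    # `setup` is a hypothetical callable wrapping the slow or hanging startup code.
    kwargs = {"file": report_file} if report_file is not None else {}
    faulthandler.dump_traceback_later(timeout_s, exit=False, **kwargs)
    try:
        return setup()
    finally:
        # Cancel the pending dump once setup() has returned or raised.
        faulthandler.cancel_dump_traceback_later()


def slow_setup():
    # Stand-in for model/retriever configuration that blocks for a while.
    time.sleep(0.5)
    return "ready"


# Demo: write the report to a temporary file instead of stderr.
report = tempfile.TemporaryFile(mode="w+")
result = run_with_hang_report(slow_setup, timeout_s=0.1, report_file=report)
report.seek(0)
report_text = report.read()
```

In a real script, the module-level setup (here, the `dspy` configuration and `RAG()` construction) would go inside the callable, and the report would go to stderr by default. The dumped stack shows which call is blocking, which helps tell a threading deadlock apart from a slow network call.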

ctyler9 commented 8 months ago

@JamesScharf, does this help narrow it down? Was it similar to your error?