stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

DSpy parallel processing #1126

Open anslin-raj opened 3 weeks ago

anslin-raj commented 3 weeks ago

I tried to run the DSPy module inside a thread, but it's not working...

I have used dspy.Module, and inside it a custom retriever module, elastic_rm(dspy.Retrieve).

Can anyone show me a proper way to do parallel processing?

tom-doerr commented 3 weeks ago

This might not be working since evaluation and some optimizers are using threading themselves. Are you trying to speed up optimization or inference?

anslin-raj commented 3 weeks ago

@tom-doerr I'm trying to speed up inference. Do you have any parallel inference code?

tom-doerr commented 3 weeks ago

Yes I do: https://x.com/tom_doerr/status/1798806436476334123

This is some separate code I use somewhere for evaluation:

import threading
import time

from dspy.teleprompt import BootstrapFewShot

# great_tweet_metric, metric_threshold, tweet_generator, teacher, trainset,
# TRAIN_SIZE, and num_nesting_levels_dict are defined elsewhere in the script.
fewshot_optimizer = BootstrapFewShot(metric=great_tweet_metric, max_bootstrapped_demos=4, metric_threshold=metric_threshold)
compile_start = time.time()

# Compile against each training example in its own thread.
threads = []
for dataset_idx in range(TRAIN_SIZE):
    t = threading.Thread(target=fewshot_optimizer.compile, kwargs=dict(student=tweet_generator, teacher=teacher, trainset=[trainset[dataset_idx]]))
    threads.append(t)
    t.start()

print("All threads have been created.")
for t in threads:
    t.join()

print("====== num_nesting_levels_dict:", num_nesting_levels_dict)
print("All threads have completed.")

tweet_generator_compiled = fewshot_optimizer.compile(student=tweet_generator, teacher=teacher, trainset=trainset)

anslin-raj commented 3 weeks ago

Thank you @tom-doerr.

I have tried TypedPredictor and TypedChainOfThought as well, but I'm facing an error; I have attached the code snippet and the error message below. I'm using asyncio for parallel processing and FastAPI for the app.

Code:

from typing import List, Optional, Union

import dspy

# `client` is an OpenAI client and `es_client_01` an Elasticsearch client,
# both created elsewhere in the app.
def oai_ef(text, model="text-embedding-ada-002"):
    return client.embeddings.create(input=text, model=model).data[0].embedding

class elastic_rm(dspy.Retrieve):
    def __init__(self, es_client, es_index, es_field, embedding_function, k=3):
        super().__init__()
        self.k=k
        self.es_index=es_index
        self.es_client=es_client
        self.field=es_field
        self.ef = embedding_function

    def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self.ef(queries)

    def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
        # Add retriever logic here (it should produce the passages list)
        return dspy.Prediction(passages=passages)

class RAG(dspy.Module):

    def __init__(self):
        super().__init__()
        # self.generate_answer = dspy.ChainOfThought("context, question -> answer")
        # self.generate_answer = dspy.TypedPredictor("context, question -> answer")
        self.generate_answer = dspy.TypedChainOfThought("context, question -> answer")

    def forward(self, question, rm, lm, num_passages=3):

        self.retrieve = rm(query_or_queries=question, k=num_passages)
        context = self.retrieve.passages

        prediction = self.generate_answer(question=question, context=context)
        return dspy.Prediction(answer=prediction.answer, context=context)

if __name__ =='__main__':
    rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
    lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
    dspy.settings.configure(lm=lm, rm=rm)
    qa = RAG()
    prompt = "Qustion"
    response = qa(prompt, rm, lm, num_passages=50)

Error:

Traceback (most recent call last):
  File "D:\app\worker\llm_worker.py", line 174, in process_request
    response = qa(prompt, rm, lm, num_passages=50)
  File "D:\app\venv\lib\site-packages\dspy\primitives\program.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "D:\app\worker\llm_worker.py", line 152, in forward
    prediction = self.generate_answer(question=question, context=context)
  File "D:\app\venv\lib\site-packages\dspy\primitives\program.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "D:\app\venv\lib\site-packages\dspy\functional\functional.py", line 295, in forward
    result = self.predictor(**modified_kwargs, new_signature=signature)
  File "D:\app\venv\lib\site-packages\dspy\predict\predict.py", line 61, in __call__
    return self.forward(**kwargs)
  File "D:\app\venv\lib\site-packages\dspy\predict\predict.py", line 111, in forward
    x, C = dsp.generate(template, **config)(x, stage=self.stage)
  File "D:\app\venv\lib\site-packages\dsp\primitives\predict.py", line 78, in do_generate
    completions: list[dict[str, Any]] = generator(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 178, in __call__
    response = self.request(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\backoff\_sync.py", line 105, in retry
    ret = target(*args, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 144, in request
    return self.basic_request(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 116, in basic_request
    kwargs = {"stringify_request": json.dumps(kwargs)}
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type GPT3 is not JSON serializable
tom-doerr commented 3 weeks ago

Is this related to parallelism? I can't see any code related to that.

anslin-raj commented 3 weeks ago

@tom-doerr, sorry for the inconvenience. Here is the updated code with parallel processing; as of now I haven't developed the FastAPI part.

Code:

import asyncio

from elasticsearch import Elasticsearch

async def process_request(es: Elasticsearch, prompt: str):
    rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
    lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
    dspy.settings.configure(lm=lm, rm=rm)
    qa = RAG()
    response = qa(prompt, rm, lm, num_passages=50)
    return response

async def process_requests(es: Elasticsearch):
    while True:
        # retrieve_new_requests() is defined elsewhere; it fetches pending prompts.
        new_requests = await retrieve_new_requests()
        tasks = [asyncio.create_task(process_request(es, request)) for request in new_requests]

        for task in asyncio.as_completed(tasks):
            try:
                result = await task
                print("Request processed successfully")
            except Exception as e:
                print(f"Error processing request: {e}")
                exit(0)

        await asyncio.sleep(10)

async def main():
    es = Elasticsearch(ELASTICSEARCH_HOST)
    await process_requests(es)

if __name__ == "__main__":
    asyncio.run(main())
tom-doerr commented 3 weeks ago

Could you just switch to a process-based worker model? That should still give you parallelism without needing to serialize the GPT3 instance.
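
For illustration, here is a minimal sketch of that idea with Python's multiprocessing: each worker process builds its own LM/RM, so the GPT3 instance never has to cross a process boundary or be serialized. It reuses elastic_rm, oai_ef, es_client_01, and OPENAI_API_KEY from the snippets above (assumed to be defined at module level so spawned workers can import them), and the prompts and worker count are placeholders; it is not tested against your setup.

import multiprocessing as mp

import dspy

def handle_prompt(prompt: str) -> str:
    # Each worker process configures its own LM/RM, so nothing DSPy-related
    # is ever pickled between processes.
    lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
    rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
    dspy.settings.configure(lm=lm, rm=rm)
    qa = dspy.ChainOfThought("context, question -> answer")
    context = rm(prompt, k=3).passages
    return qa(question=prompt, context=context).answer

if __name__ == "__main__":
    prompts = ["question 1", "question 2"]  # placeholder inputs
    with mp.Pool(processes=4) as pool:
        answers = pool.map(handle_prompt, prompts)
    print(answers)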

anslin-raj commented 3 weeks ago

@tom-doerr I'm not familiar with the process-based worker model. Do you have any sample code for it that you could share?

tom-doerr commented 3 weeks ago

As far as I know, it works using uvicorn or gunicorn.
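
For reference, a process-based FastAPI setup could look roughly like the sketch below, where uvicorn starts several worker processes and each one configures DSPy at import time. The module name llm_api, the /ask endpoint, and the worker count are placeholders, and it again assumes elastic_rm, oai_ef, es_client_01, and OPENAI_API_KEY from the earlier snippets.

# llm_api.py -- hypothetical module name
import dspy
from fastapi import FastAPI

# Each uvicorn worker process runs this module-level setup once.
lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
dspy.settings.configure(lm=lm, rm=rm)

app = FastAPI()
qa = dspy.ChainOfThought("context, question -> answer")

@app.get("/ask")
def ask(question: str):
    context = rm(question, k=3).passages
    return {"answer": qa(question=question, context=context).answer}

if __name__ == "__main__":
    import uvicorn
    # workers > 1 requires passing the app as an import string.
    uvicorn.run("llm_api:app", host="0.0.0.0", port=8000, workers=4)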

anslin-raj commented 3 weeks ago

Thank you @tom-doerr, I understand, but I planned to manage the OpenAI calls in one centralized place: once we send more requests and more data, we may hit token-limit and rate-limit exceptions. That's why I chose this structure; it gives us full control over the OpenAI calls and lets us customize data retrieval.

Does DSPy automatically handle these exceptions?

Do you know any other way to solve this?
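
For what it's worth, one way to keep that centralized control while still processing requests concurrently would be to funnel every LM-backed call through a shared asyncio.Semaphore. This is only a rough sketch under that assumption; it reuses the process_request coroutine above, and MAX_CONCURRENT_REQUESTS is an illustrative placeholder to be tuned against the account's rate limits.

import asyncio

MAX_CONCURRENT_REQUESTS = 5  # placeholder; tune against your rate limits
_request_slots = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

async def process_request_throttled(es, prompt: str):
    # All requests funnel through this one semaphore, so the number of
    # simultaneous OpenAI calls stays bounded no matter how many tasks exist.
    async with _request_slots:
        return await process_request(es, prompt)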

tom-doerr commented 3 weeks ago

Not really sure why having multiple instances would trigger token or rate limits faster, or how having it centralized helps with data retrieval. You could try to make it serializable; I'm not sure how feasible that is, though. Other ideas:

anslin-raj commented 3 weeks ago

No @tom-doerr, both multiple instances and a single instance trigger the exception depending on the count and size of the requests; if we handle all requests in a single instance, it is easier to avoid hitting those limits.