stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

DatabricksRM retrieval using dspy.Retrieve() throws TypeError #1191

Open josh-melton-db opened 6 days ago

josh-melton-db commented 6 days ago

Following the pattern from the simple RAG example in the docs, I've created a DatabricksRM that works when called directly, e.g. rm(query="Model serving API", query_type="text").

But when I use dspy.settings.configure(rm=rm) and dspy.Retrieve() as below:

import dspy
from dspy.retrieve.databricks_rm import DatabricksRM

token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
serving_url = url + '/serving-endpoints'

lm = dspy.Databricks(model='databricks-dbrx-instruct', model_type='chat', api_key=token, 
                     api_base=serving_url, max_tokens=1000, temperature=0.85)
teacher = dspy.Databricks(model='databricks-meta-llama-3-70b-instruct', model_type='chat', api_key=token, 
                          api_base=serving_url, max_tokens=1000, temperature=0)
rm = DatabricksRM( # This index was created using the Databricks Demo Center RAG Tutorial
    databricks_index_name="catalog.schema.databricks_documentation_vs_index",
    databricks_endpoint=url,
    databricks_token=token,
    columns=["content"],
    text_column_name="content",
    docs_id_column_name="id",
)
dspy.settings.configure(lm=lm, rm=rm)

retrieve = dspy.Retrieve()
retrieve(query_or_queries="What is Apache Spark?", query_type="text")

I get TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'

josh-melton-db commented 6 days ago

TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
File , line 23
     20 dspy.settings.configure(lm=lm, rm=rm)
     22 retrieve = dspy.Retrieve()
---> 23 retrieve(query_or_queries="What is Apache Spark?", query_type="text")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:30, in Retrieve.__call__(self, *args, **kwargs)
     29 def __call__(self, *args, **kwargs):
---> 30     return self.forward(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:39, in Retrieve.forward(self, query_or_queries, k, **kwargs)
     36 # print(queries)
     37 # TODO: Consider removing any quote-like markers that surround the query too.
     38 k = k if k is not None else self.k
---> 39 passages = dsp.retrieveEnsemble(queries, k=k,**kwargs)
     40 return Prediction(passages=passages)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dsp/primitives/search.py:57, in retrieveEnsemble(queries, k, by_prob, **kwargs)
     54 queries = [q for q in queries if q]
     56 if len(queries) == 1:
---> 57     return retrieve(queries[0], k, **kwargs)
     59 passages = {}
     60 for q in queries:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dsp/primitives/search.py:12, in retrieve(query, k, **kwargs)
     10 if not dsp.settings.rm:
     11     raise AssertionError("No RM is loaded.")
---> 12 passages = dsp.settings.rm(query, k=k, **kwargs)
     13 if not isinstance(passages, Iterable):
     14     # it's not an iterable yet; make it one.
     15     # TODO: we should unify the type signatures of dspy.Retriever
     16     passages = [passages]
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:30, in Retrieve.__call__(self, *args, **kwargs)
     29 def __call__(self, *args, **kwargs):
---> 30     return self.forward(*args, **kwargs)
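
For anyone reading along, the call chain in the traceback can be reproduced without Databricks or DSPy installed. This is only a minimal sketch: BaseRetriever, RMWithoutK and the retrieve helper below are hypothetical stand-ins that copy the call shape shown above, not the real dspy classes.

```python
# Stand-ins only: these are NOT the real dspy classes, just the same call shape.


class BaseRetriever:
    """Mimics a base class whose __call__ forwards everything to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class RMWithoutK(BaseRetriever):
    """Mimics a retriever whose forward() does not accept `k`."""

    def forward(self, query, query_type="text"):
        return [f"passage for {query!r} ({query_type})"]


def retrieve(rm, query, k, **kwargs):
    """Mimics the helper in the traceback that always passes k= to the RM."""
    return rm(query, k=k, **kwargs)


rm = RMWithoutK()

# Calling the retriever directly works, because no `k` is passed.
direct = rm("Model serving API", query_type="text")

# Going through the helper fails, because `k` is always injected.
try:
    retrieve(rm, "What is Apache Spark?", k=3, query_type="text")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(direct)
print(error_message)
```

This would explain why the direct rm(...) call succeeds while the dspy.Retrieve() path raises: only the second path adds k to the keyword arguments.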

arnavsinghvi11 commented 6 days ago

Ah, we just need to add k as an optional parameter to the DatabricksRM forward pass, since the DSPy retriever interface passes it as a keyword argument (it might have just been missed in the implementation). Other RM providers have this as well (reference).

Let me know if that resolves this and we can get a PR going to update that!
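
A rough sketch of the shape of that change, under stated assumptions: BaseRetriever and RMWithK are hypothetical stand-ins rather than the actual DatabricksRM code, and the word-overlap ranking is only a placeholder for the real vector search call. The part that matters is the optional k parameter on forward(), falling back to the instance default in the same way as the `k = k if k is not None else self.k` line in the traceback.

```python
from typing import List, Optional


class BaseRetriever:
    """Stand-in base class: __call__ forwards all arguments to forward()."""

    def __init__(self, k: int = 3):
        self.k = k

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class RMWithK(BaseRetriever):
    """Stand-in retriever whose forward() accepts an optional `k`."""

    def __init__(self, corpus: List[str], k: int = 3):
        super().__init__(k=k)
        self.corpus = corpus

    def forward(self, query: str, query_type: str = "text",
                k: Optional[int] = None) -> List[str]:
        # Use the per-call k if given, otherwise the instance default.
        k = k if k is not None else self.k
        # Placeholder ranking: count words shared between query and document.
        words = set(query.lower().split())
        ranked = sorted(
            self.corpus,
            key=lambda doc: len(words & set(doc.lower().split())),
            reverse=True,
        )
        return ranked[:k]


corpus = [
    "Apache Spark is a unified analytics engine",
    "Model serving exposes models as REST endpoints",
    "Delta Lake is a storage layer",
]
rm = RMWithK(corpus, k=2)

top_one = rm("apache spark", k=1, query_type="text")  # per-call k
top_default = rm("apache spark")                      # falls back to self.k
print(top_one)
print(len(top_default))
```

With a signature like this, a caller that always passes k=... no longer raises a TypeError, and a direct call without k still works.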