stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy.ai
MIT License

DatabricksRM retrieval using dspy.Retrieve() throws TypeError #1191

Open josh-melton-db opened 4 months ago

josh-melton-db commented 4 months ago

Following the pattern from the simple RAG example in the docs, I've created a DatabricksRM, which works when called directly, like rm(query="Model serving API", query_type="text").

But when I try to use it through dspy.settings.configure(rm=rm) and dspy.Retrieve(), as below:

import dspy
from dspy.retrieve.databricks_rm import DatabricksRM

token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
serving_url = url + '/serving-endpoints'

lm = dspy.Databricks(model='databricks-dbrx-instruct', model_type='chat', api_key=token, 
                     api_base=serving_url, max_tokens=1000, temperature=0.85)
teacher = dspy.Databricks(model='databricks-meta-llama-3-70b-instruct', model_type='chat', api_key=token, 
                          api_base=serving_url, max_tokens=1000, temperature=0)
rm = DatabricksRM( # This index was created using the Databricks Demo Center RAG Tutorial
    databricks_index_name="catalog.schema.databricks_documentation_vs_index",
    databricks_endpoint=url,
    databricks_token=token,
    columns=["content"],
    text_column_name="content",
    docs_id_column_name="id",
)
dspy.settings.configure(lm=lm, rm=rm)

retrieve = dspy.Retrieve()
retrieve(query_or_queries="What is Apache Spark?", query_type="text")

I get TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'

josh-melton-db commented 4 months ago

TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
File , line 23
     20 dspy.settings.configure(lm=lm, rm=rm)
     22 retrieve = dspy.Retrieve()
---> 23 retrieve(query_or_queries="What is Apache Spark?", query_type="text")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:30, in Retrieve.__call__(self, *args, **kwargs)
     29 def __call__(self, *args, **kwargs):
---> 30     return self.forward(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:39, in Retrieve.forward(self, query_or_queries, k, **kwargs)
     36 # print(queries)
     37 # TODO: Consider removing any quote-like markers that surround the query too.
     38 k = k if k is not None else self.k
---> 39 passages = dsp.retrieveEnsemble(queries, k=k, **kwargs)
     40 return Prediction(passages=passages)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dsp/primitives/search.py:57, in retrieveEnsemble(queries, k, by_prob, **kwargs)
     54 queries = [q for q in queries if q]
     56 if len(queries) == 1:
---> 57     return retrieve(queries[0], k, **kwargs)
     59 passages = {}
     60 for q in queries:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dsp/primitives/search.py:12, in retrieve(query, k, **kwargs)
     10 if not dsp.settings.rm:
     11     raise AssertionError("No RM is loaded.")
---> 12 passages = dsp.settings.rm(query, k=k, **kwargs)
     13 if not isinstance(passages, Iterable):
     14     # it's not an iterable yet; make it one.
     15     # TODO: we should unify the type signatures of dspy.Retriever
     16     passages = [passages]
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-a683426f-5c6f-4490-a4e5-5a90721a88d7/lib/python3.11/site-packages/dspy/retrieve/retrieve.py:30, in Retrieve.__call__(self, *args, **kwargs)
     29 def __call__(self, *args, **kwargs):
---> 30     return self.forward(*args, **kwargs)

arnavsinghvi11 commented 4 months ago

Ah, we just need to add k as an optional parameter to the DatabricksRM forward pass, since the DSPy retriever interface passes it through the kwargs (it might have just been missed in the implementation). Other RM providers have this as well (reference).

Let me know if that resolves this and we can get a PR going to update that!
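For illustration, here is a minimal sketch of the signature change being suggested, assuming a toy in-memory retriever. `SketchRM`, its word-overlap scoring, and its plain-list return value are hypothetical stand-ins, not the real DatabricksRM (which queries a Vector Search index and returns a Prediction). The point is only that forward() accepts an optional k and falls back to the instance default, the same pattern Retrieve.forward uses in the traceback.

```python
import re
from typing import List, Optional


class SketchRM:
    """Hypothetical stand-in for DatabricksRM; NOT the real class.

    Only the forward() signature matters here: it accepts an optional k,
    so a caller that passes k=... as a keyword no longer raises TypeError.
    """

    def __init__(self, corpus: List[str], k: int = 3):
        self.corpus = corpus
        self.k = k  # default used when the caller does not pass k

    def __call__(self, *args, **kwargs):
        # Mirrors Retrieve.__call__ from the traceback: delegate to forward().
        return self.forward(*args, **kwargs)

    def forward(self, query: str, query_type: str = "text",
                k: Optional[int] = None, **kwargs) -> List[str]:
        # Same fallback as Retrieve.forward: k = k if k is not None else self.k
        k = k if k is not None else self.k
        terms = set(re.findall(r"\w+", query.lower()))

        def overlap(doc: str) -> int:
            # Toy relevance score: number of shared lowercase words.
            return len(terms & set(re.findall(r"\w+", doc.lower())))

        # sorted() stays stable with reverse=True, so ties keep corpus order.
        return sorted(self.corpus, key=overlap, reverse=True)[:k]


rm = SketchRM(["Apache Spark is a data engine", "Model serving API docs", "Spark SQL guide"], k=2)
print(rm(query="spark", query_type="text"))  # ['Apache Spark is a data engine', 'Spark SQL guide']
print(rm("What is Apache Spark?", k=1))      # ['Apache Spark is a data engine']
```

Accepting **kwargs as well keeps the sketch tolerant of any other keywords the caller forwards.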

kvnlngpenn commented 3 months ago

Running into this, any update on a merge?

josh-melton-db commented 3 months ago

Haven't gotten to this yet; I'll try to get to it this weekend. In the meantime I'm just calling the RM directly:

from dspy.retrieve.databricks_rm import DatabricksRM

# rag_config, url, and token are defined earlier in the notebook.
# Setting k at construction and calling the RM directly avoids going through
# dspy.Retrieve(), which passes k as a keyword argument to forward().
rm = DatabricksRM(
    databricks_index_name=rag_config.get("vector_search_index"),
    databricks_endpoint=url,
    databricks_token=token,
    columns=[rag_config.get("chunk_column_name")],
    text_column_name=rag_config.get("chunk_column_name"),
    docs_id_column_name=rag_config.get("document_source_id"),
    k=rag_config.get("vector_search_parameters").get("k"),
)
rm(query="Transportation and logistic issues", query_type="text")

kvnlngpenn commented 3 months ago

Thanks! Your workaround works when calling the retriever directly, but fails inside modules:

teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
compiled_rag = teleprompter.compile(RAG(), trainset=[trainset[0]])

ERROR:dspy.teleprompt.bootstrap due to DatabricksRM.forward() got an unexpected keyword argument 'k'. [dspy.teleprompt.bootstrap] filename=bootstrap.py lineno=211
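Why the constructor-level k doesn't help inside modules can be sketched with a toy reproduction. `OldStyleRM` and `simplified_retrieve` are hypothetical names: the first mimics a forward() with no k parameter, and the second is a simplified stand-in for dsp.primitives.search.retrieve, which (per the first traceback in this thread) calls the configured RM as rm(query, k=k, **kwargs). Storing k on the instance does not change the fact that the caller still passes k as a keyword.

```python
from typing import List


class OldStyleRM:
    """Hypothetical retriever whose forward() has no k parameter,
    mirroring the DatabricksRM signature reported in this issue."""

    def __init__(self, corpus: List[str], k: int = 3):
        self.corpus = corpus
        self.k = k  # k fixed at construction time, as in the workaround

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, query: str, query_type: str = "text") -> List[str]:
        return self.corpus[: self.k]


def simplified_retrieve(rm, query: str, k: int) -> List[str]:
    # Simplified stand-in for dsp.primitives.search.retrieve, which calls
    # the configured RM as rm(query, k=k, **kwargs).
    return rm(query, k=k)


rm = OldStyleRM(["doc a", "doc b", "doc c"], k=2)
direct = rm(query="x", query_type="text")  # direct call: works

error = None
try:
    simplified_retrieve(rm, "x", k=2)  # module path: k is still passed
except TypeError as err:
    error = err
print(direct, error)
```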

josh-melton-db commented 3 months ago

Temporary workaround in modules: call the RM directly instead of going through dspy.Retrieve()

import dspy

# SectionToParagraphSig is a dspy.Signature defined elsewhere in the notebook.
class SectionToParagraph(dspy.Module):
    def __init__(self, docs_rm, iterations=3):
        super().__init__()
        self.docs_rm = docs_rm  # a DatabricksRM instance, called directly
        self.prog = dspy.ChainOfThought(SectionToParagraphSig)
        self.iterations = iterations

    def get_context(self, query):
        # Calling the RM directly (not via dspy.Retrieve) means no k kwarg is passed.
        context_list = self.docs_rm(query=query, query_type="text").docs
        return "\n".join(context_list)

    def forward(self, section, abstract):
        context = self.get_context(abstract + "\n" + section)
        output = self.prog(section=section, abstract=abstract, context="")
        for _ in range(self.iterations):
            output = self.prog(section=section, abstract=abstract, context=context)
        return output

Obviously this is less convenient and less effective, but it shouldn't be too tough a fix, and I'll prioritize getting to it!

kvnlngpenn commented 3 months ago

Thanks for that! I'm up and running. Much appreciated.

danpechi commented 3 months ago

BTW, we made some additional updates to DatabricksRM to better handle returning additional columns and variable table structures: https://github.com/stanfordnlp/dspy/commit/f603036ecfe8ac36053b1d6952d9cf3907937b23

kvnlngpenn commented 3 months ago

That's handy. Is that through filters_json? I assumed I could accomplish this through the columns= parameter when instantiating the object.

josh-melton-db commented 3 months ago

I think filters_json passes filters (which remove rows), as opposed to columns (which specifies which columns are returned).
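The rows-versus-columns distinction can be sketched in plain Python. This is an analogy on an in-memory list of dicts, not the Databricks Vector Search API: `apply_filters` and `select_columns` are hypothetical helpers, and the real filters_json may support richer operators than the exact-match equality assumed here.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]

rows: List[Row] = [
    {"id": 1, "content": "Intro to Apache Spark", "source": "docs"},
    {"id": 2, "content": "Model serving API", "source": "blog"},
    {"id": 3, "content": "Spark SQL guide", "source": "docs"},
]


def apply_filters(rows: List[Row], filters: Row) -> List[Row]:
    # Row-level: drop rows that don't match every filter (like a SQL WHERE).
    return [r for r in rows if all(r.get(key) == value for key, value in filters.items())]


def select_columns(rows: List[Row], columns: List[str]) -> List[Row]:
    # Column-level: keep every row, but only the listed fields (like SQL SELECT).
    return [{c: r[c] for c in columns} for r in rows]


print(len(apply_filters(rows, {"source": "docs"})))  # 2 rows, all columns
print(select_columns(rows, ["content"]))              # 3 rows, one column each
```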

josh-melton-db commented 3 months ago

@arnavsinghvi11 adding k as an optional parameter to DatabricksRM.forward() throws a different error:

AttributeError: 'str' object has no attribute 'long_text'
File <command-2032780391400199>, line 161
    158 dspy.settings.configure(lm=lm, rm=rm)
    160 retrieve = dspy.Retrieve()
--> 161 retrieve(query_or_queries="What is Apache Spark?", query_type="text")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dspy/retrieve/retrieve.py:40, in Retrieve.__call__(self, *args, **kwargs)
     39 def __call__(self, *args, **kwargs):
---> 40     return self.forward(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dspy/retrieve/retrieve.py:69, in Retrieve.forward(self, query_or_queries, k, by_prob, with_metadata, **kwargs)
     67 k = k if k is not None else self.k
     68 if not with_metadata:
---> 69     passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs)
     70     return Prediction(passages=passages)
     71 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:93, in retrieveEnsemble(queries, k, by_prob, **kwargs)
     90 queries = [q for q in queries if q]
     92 if len(queries) == 1:
---> 93     return retrieve(queries[0], k, **kwargs)
     95 passages = {}
     96 for q in queries:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:19, in retrieve(query, k, **kwargs)
     15 if not isinstance(passages, Iterable):
     16     # it's not an iterable yet; make it one.
     17     # TODO: we should unify the type signatures of dspy.Retriever
     18     passages = [passages]
---> 19 passages = [psg.long_text for psg in passages]
     21 if dsp.settings.reranker:
     22     passages_cs_scores = dsp.settings.reranker(query, passages)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:19, in <listcomp>(.0)
     15 if not isinstance(passages, Iterable):
     16     # it's not an iterable yet; make it one.
     17     # TODO: we should unify the type signatures of dspy.Retriever
     18     passages = [passages]
---> 19 passages = [psg.long_text for psg in passages]
     21 if dsp.settings.reranker:
     22     passages_cs_scores = dsp.settings.reranker(query, passages)

I don't follow why a long_text attribute is expected there. Am I making some simple mistake in my approach?

kvnlng commented 3 months ago

I think one error in that code is that it assumes a string is not Iterable (a str actually is Iterable, so a bare string skips the wrapping on line 18). Another is that passages = [passages] on line 18 only wraps the value in a list; it doesn't add the long_text attribute expected on line 19. There is also a TODO there about unifying the type signatures of dspy.Retriever. Line 93 of the Retriever class seems to indicate that a passage "will" contain long_text if it is a dict.
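A toy reproduction of the consumer logic in the second traceback may make the failure concrete. `extract_long_text` is a hypothetical, simplified replica of lines 15-19 of dsp/primitives/search.py as shown there. Both a list of strings and a bare string end up iterating str elements, which have no long_text. The "fixed" shape uses `SimpleNamespace` as a stand-in for whatever dict-like object with attribute access the library's own retrievers return; that exact type is an assumption here.

```python
from collections.abc import Iterable
from types import SimpleNamespace
from typing import Any, List


def extract_long_text(passages: Any) -> List[str]:
    # Simplified replica of search.py lines 15-19: wrap non-iterables,
    # then read .long_text from every element.
    if not isinstance(passages, Iterable):
        passages = [passages]
    return [psg.long_text for psg in passages]


docs = ["Apache Spark is a data engine", "Spark SQL guide"]

errors = []
for bad in (docs, docs[0]):
    # A list of str, or a bare str (which IS Iterable, so it is not wrapped
    # and gets iterated character by character): every element is a str.
    try:
        extract_long_text(bad)
    except AttributeError as err:
        errors.append(err)

# Hypothetical RM-side shape that satisfies the consumer: objects exposing
# a long_text attribute (SimpleNamespace is only a stand-in type).
wrapped = [SimpleNamespace(long_text=d) for d in docs]
print(extract_long_text(wrapped))
```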