josh-melton-db opened this issue 4 months ago (status: open)
TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
Ah, we just need to add k as an optional parameter to the DatabricksRM forward pass, since the DSPy retriever interface passes it through the kwargs (it might have just been missed in the implementation). Other RM providers have this as well (reference).
Let me know if that resolves this, and we can get a PR going to update it!
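A minimal sketch of what that change could look like, assuming a retriever whose constructor stores a default `k` and whose `forward` accepts an optional per-call override that falls back to that default. `SketchRM` and its toy word-overlap matching are hypothetical stand-ins for illustration; this is not the actual DatabricksRM code.

```python
from typing import List, Optional


class SketchRM:
    """Hypothetical retriever illustrating an optional `k` in forward()."""

    def __init__(self, docs: List[str], k: int = 3):
        self.docs = docs
        self.k = k  # default number of results, set at construction time

    def forward(self, query: str, query_type: str = "text",
                k: Optional[int] = None, **kwargs) -> List[str]:
        # Accept `k` because callers such as dspy.Retrieve pass it as a kwarg;
        # fall back to the constructor default when it is omitted.
        k = k if k is not None else self.k
        # Toy relevance: keep docs sharing any lowercase word with the query.
        words = set(query.lower().split())
        hits = [d for d in self.docs if words & set(d.lower().split())]
        return hits[:k]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
```

With `rm = SketchRM(["apache spark guide", "spark streaming", "delta lake"], k=1)`, calling `rm("apache spark")` uses the default of one result, while `rm("apache spark", k=2)` overrides it per call, which is the shape of call the retriever interface makes.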
Running into this, any update on a merge?
Haven't gotten to this yet; I'll try to get to it this weekend. For now I'm just calling the retriever directly:
```python
from dspy.retrieve.databricks_rm import DatabricksRM

# rag_config, url and token come from your own setup
rm = DatabricksRM(
    databricks_index_name=rag_config.get("vector_search_index"),
    databricks_endpoint=url,
    databricks_token=token,
    columns=[rag_config.get("chunk_column_name")],
    text_column_name=rag_config.get("chunk_column_name"),
    docs_id_column_name=rag_config.get("document_source_id"),
    k=rag_config.get("vector_search_parameters").get("k"),
)

# Call the retriever directly instead of going through dspy.Retrieve()
rm(query="Transportation and logistic issues", query_type="text")
```
Thanks! Your workaround works for the retriever on its own but fails inside modules:
```python
from dspy.teleprompt import BootstrapFewShot

teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
compiled_rag = teleprompter.compile(RAG(), trainset=[trainset[0]])
```
ERROR:dspy.teleprompt.bootstrap
Temporary workaround in modules:

```python
import dspy


# SectionToParagraphSig is the author's own signature, defined elsewhere.
class SectionToParagraph(dspy.Module):
    def __init__(self, docs_rm, iterations=3):
        super().__init__()
        self.docs_rm = docs_rm  # DatabricksRM instance, called directly
        self.prog = dspy.ChainOfThought(SectionToParagraphSig)
        self.iterations = iterations

    def get_context(self, query):
        # Call the retriever directly (no `k` kwarg) and join the returned docs.
        context_list = self.docs_rm(query=query, query_type="text").docs
        return "\n".join(context_list)

    def forward(self, section, abstract):
        context = self.get_context(abstract + "\n" + section)
        output = self.prog(section=section, abstract=abstract, context="")
        for _ in range(self.iterations):
            output = self.prog(section=section, abstract=abstract, context=context)
        return output
```
This is obviously less convenient and less effective, but it shouldn't be a tough fix, and I'll prioritize getting to it!
Thanks for that! I'm up and running. Much appreciated
btw we made some additional updates to DatabricksRM to better handle returning additional columns, and variable table structures: https://github.com/stanfordnlp/dspy/commit/f603036ecfe8ac36053b1d6952d9cf3907937b23
That's handy. Is that through filters_json? I assumed I could accomplish that through the columns= parameter when instantiating the object.
I think filters_json passes filters (which remove rows), as opposed to columns (which selects which columns are returned).
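To make the rows-versus-columns distinction concrete, here is a plain-Python analogy on a list of dict rows: a filter drops whole rows, while a column list keeps every row but trims the fields returned. The sample rows and the `apply_filters`/`select_columns` helpers are hypothetical; this only models the semantics described above, not how DatabricksRM or Databricks Vector Search actually implements `filters_json` or `columns`.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def apply_filters(rows: List[Row], filters: Dict[str, Any]) -> List[Row]:
    """Row selection: keep only rows whose fields equal every filter value."""
    return [r for r in rows if all(r.get(key) == val for key, val in filters.items())]


def select_columns(rows: List[Row], columns: List[str]) -> List[Row]:
    """Projection: keep every row, but only the requested fields."""
    return [{c: r[c] for c in columns if c in r} for r in rows]


rows = [
    {"id": 1, "source": "docs", "chunk": "Spark overview"},
    {"id": 2, "source": "blog", "chunk": "Delta Lake tips"},
]
filtered = apply_filters(rows, {"source": "docs"})  # fewer rows, all fields
projected = select_columns(rows, ["id", "chunk"])   # all rows, fewer fields
```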
@arnavsinghvi11 adding k as a parameter throws a different error:
```
AttributeError: 'str' object has no attribute 'long_text'
File <command-2032780391400199>, line 161
158 dspy.settings.configure(lm=lm, rm=rm)
160 retrieve = dspy.Retrieve()
--> 161 retrieve(query_or_queries="What is Apache Spark?", query_type="text")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dspy/retrieve/retrieve.py:40, in Retrieve.__call__(self, *args, **kwargs)
39 def __call__(self, *args, **kwargs):
---> 40 return self.forward(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dspy/retrieve/retrieve.py:69, in Retrieve.forward(self, query_or_queries, k, by_prob, with_metadata, **kwargs)
67 k = k if k is not None else self.k
68 if not with_metadata:
---> 69 passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs)
70 return Prediction(passages=passages)
71 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:93, in retrieveEnsemble(queries, k, by_prob, **kwargs)
90 queries = [q for q in queries if q]
92 if len(queries) == 1:
---> 93 return retrieve(queries[0], k, **kwargs)
95 passages = {}
96 for q in queries:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:19, in retrieve(query, k, **kwargs)
15 if not isinstance(passages, Iterable):
16 # it's not an iterable yet; make it one.
17 # TODO: we should unify the type signatures of dspy.Retriever
18 passages = [passages]
---> 19 passages = [psg.long_text for psg in passages]
21 if dsp.settings.reranker:
22 passages_cs_scores = dsp.settings.reranker(query, passages)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-2dccb997-785d-48eb-b9a6-90ae447976cf/lib/python3.10/site-packages/dsp/primitives/search.py:19, in <listcomp>(.0)
15 if not isinstance(passages, Iterable):
16 # it's not an iterable yet; make it one.
17 # TODO: we should unify the type signatures of dspy.Retriever
18 passages = [passages]
---> 19 passages = [psg.long_text for psg in passages]
21 if dsp.settings.reranker:
22 passages_cs_scores = dsp.settings.reranker(query, passages)
```
I'm not following why there's a long_text attribute there - am I making some simple mistake in my approach?
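I think the problem in that code is that it assumes every passage object has a long_text attribute, while the items being iterated here are plain str objects (as the AttributeError shows). A str is itself Iterable, so the isinstance check on line 15 doesn't catch it. And even when line 18 does run, passages = [passages] only wraps the value in a list; it doesn't add the long_text attribute that line 19 expects. There's also a TODO there about unifying the type signatures of dspy.Retriever. Line 93 of the Retriever class seems to indicate that a passage "will" contain long_text if it is a dict.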
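One way to make that step tolerant of the shapes discussed here is to normalize each item before reading its text. The sketch below uses hypothetical helpers (`passage_text` and `to_passage_texts` are not DSPy functions) and assumes passages may arrive as plain strings, dicts with a `long_text` key, or objects exposing a `long_text` attribute. A lone str or dict is wrapped in a list first, because both are Iterable and would otherwise be iterated as characters or keys.

```python
from collections.abc import Iterable
from typing import Any, List


def passage_text(psg: Any) -> str:
    """Extract the text of one passage, whatever shape it arrives in."""
    if isinstance(psg, str):
        return psg  # already plain text
    if isinstance(psg, dict):
        return psg["long_text"]  # dict-style passage
    return psg.long_text  # object-style passage


def to_passage_texts(passages: Any) -> List[str]:
    # str and dict are Iterable, so test for them before the Iterable check;
    # otherwise a single passage would be split into characters or keys.
    if isinstance(passages, (str, dict)) or not isinstance(passages, Iterable):
        passages = [passages]
    return [passage_text(p) for p in passages]
```

With this normalization, a retriever returning `["a", "b"]` yields `["a", "b"]` instead of raising `AttributeError: 'str' object has no attribute 'long_text'`.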
Following the pattern from the simple RAG example in the docs, I've created a DatabricksRM, which works when I call it directly:
```python
rm(query="Model serving API", query_type="text")
```
But when I try to use it through dspy.settings.configure(rm=rm) and dspy.Retrieve() like below, I get:
TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
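The mismatch can be reproduced without Databricks at all. In this sketch, `NoKRetriever` is a hypothetical stand-in whose `forward` has no `k` parameter, and `retrieve_like_dspy` is a hypothetical caller that, like the `dspy.Retrieve` path in the traceback above, threads `k` through and passes it to the retriever as a keyword argument.

```python
from typing import List


class NoKRetriever:
    """Hypothetical retriever whose forward() does not accept `k`."""

    def forward(self, query: str, query_type: str = "text") -> List[str]:
        return [f"passage about {query}"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def retrieve_like_dspy(rm, query: str, k: int = 3, **kwargs) -> List[str]:
    # Hypothetical caller: forwards `k` as a keyword argument to the retriever.
    return rm(query, k=k, **kwargs)


def reproduce() -> str:
    rm = NoKRetriever()
    rm(query="Model serving API", query_type="text")  # direct call works
    try:
        retrieve_like_dspy(rm, "Model serving API", query_type="text")
    except TypeError as err:
        return str(err)  # mirrors the TypeError reported in this issue
    return "no error"
```

Adding an optional `k` to `forward` (and ignoring or honoring it) makes the second call succeed, which is the fix proposed earlier in the thread.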