Closed hugoferrero closed 3 months ago
I am adding the following notebook, which I believe the code example above is derived from: https://github.com/vanna-ai/vanna/blob/fb384d43a0fb50a7cbe366cd6c242d5d37b16569/notebooks/bigquery-other-llm-chromadb.ipynb
Would be great to have a notebook describing how to do this or similar :]
Perhaps you want to take a look at the recently added Ollama implementation:
https://github.com/vanna-ai/vanna/blob/fb384d43a0fb50a7cbe366cd6c242d5d37b16569/src/vanna/ollama/__init__.py#L6
I added the code below, which should give you some ideas of what is required to add support for any other LLM model.
Perhaps that is exactly what you were looking for, @hugoferrero? :]
```python
from ..base import VannaBase
import requests
import json
import re


class Ollama(VannaBase):
    def __init__(self, config=None):
        if config is None or 'ollama_host' not in config:
            self.host = "http://localhost:11434"
        else:
            self.host = config['ollama_host']

        if config is None or 'model' not in config:
            raise ValueError("config must contain an Ollama model")
        else:
            self.model = config['model']

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or the
          original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace('```', '')
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")
        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        url = f"{self.host}/api/chat"
        data = {
            "model": self.model,
            "stream": False,
            "messages": prompt,
        }
        response = requests.post(url, json=data)
        response_dict = response.json()

        self.log(response.text)

        return response_dict['message']['content']
```
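As a standalone illustration of the extraction step, the regex keeps only the span from the first `select` (case-insensitive) up to the first `;`, triple backtick, or end of string (a minimal sketch mirroring the helper above, not the packaged implementation):

```python
import re

def extract_sql_query(text):
    # Capture from the first 'select' (case-insensitive) up to ';', '```', or end of string
    pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)
    match = pattern.search(text)
    if match:
        # Strip any triple backticks that got captured
        return match.group(0).replace('```', '')
    return text

reply = "Sure! Here is the query:\n```sql\nSELECT name FROM users WHERE active = 1;\n```"
print(extract_sql_query(reply))  # SELECT name FROM users WHERE active = 1;
```

This is why the class works even when the LLM wraps the query in prose or a markdown code block.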
Thanks for the response @andreped. I will try it and send you feedback.
@hugoferrero if you happen to make progress on this, could you pass along your code and we can potentially integrate this into the main Vanna repo?
> @hugoferrero if you happen to make progress on this, could you pass along your code and we can potentially integrate this into the main Vanna repo?
@zainhoda I am open to drafting a PR for it :] I can tag you in, @hugoferrero, if you wish to test it before merging.
I made a PR https://github.com/vanna-ai/vanna/pull/264.
It is a rather simple implementation but sadly I do not have access to Google Cloud. I am therefore dependent on some of you to test it.
To install:
```shell
pip install git+https://github.com/andreped/vanna.git@bison-support
pip install chromadb google-cloud-aiplatform
```
Then you should be able to initialize it with a vector DB like Chroma like so:
```python
from vertexai.language_models import ChatModel
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.palm.palm import Palm


class MyVanna(ChromaDB_VectorStore, Palm):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Palm.__init__(self, client=ChatModel("chat-bison@001"), config=config)


vn = MyVanna()

# do as you normally would with any other client(s)
# [...]
```
@andreped I've tried the code and it's not working, it gives errors like: 'MyVanna' object has no attribute 'temperature'
I then tried explicitly mentioning the parameters as below:
```python
def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40
    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    response = self.client.send_message(prompt, **params)
    return response.text
```
This gave me the error: 'ChatModel' object has no attribute 'send_message'
I also attempted to specify the model more explicitly:
```python
def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40
    client = ChatModel("chat-bison@001")
    chat = client.start_chat()
    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    print(prompt)
    response = chat.send_message(prompt, **params)
    return response.text
```
This gave me this error: 400 Invalid resource field value in the request.
Please help 🙏🏼
> @andreped I've tried the code and it's not working, it gives errors like: 'MyVanna' object has no attribute 'temperature'
Hello, @yedhukr! :] Great that you were able to test the implementation!
I don't have access to Google Cloud, so I have no way of testing it. Perhaps someone could reach out, and I could borrow some API key such that I could debug this properly? Just for this PR, then the key could be rotated. @zainhoda?
@andreped Let me know if there's anything else I can do to help!
Modifying the function in this manner gets it to run, but I get a response like: 'Hi there, how can I help you today?'
```python
def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40
    chat_model = ChatModel.from_pretrained("chat-bison@001")
    chat = chat_model.start_chat()
    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    response = chat.send_message("{prompt}", **params)
    return response.text
```
> Modifying the function in this manner gets it to run, but I get a response like: 'Hi there, how can I help you today?'
Which user prompt did you use? It also sounds like you are missing the system message that Vanna uses.
I think that by doing `chat = chat_model.start_chat()` you are not setting the system message. Here you can see an example which includes a system message:
https://cloud.google.com/vertex-ai/generative-ai/docs/sdk-for-llm/sdk-use-text-models#generate-text-chat-sdk

As a test, could you try to feed the system message that the Vanna base class uses here: https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py#L448

Basically, change `chat = chat_model.start_chat()` to:

```python
chat = chat_model.start_chat(
    context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
)
```

If that works, I think I know how to fix the issue.

EDIT: It could also be that you just wrote "Hello" as the user prompt; in that case I think I have gotten the same reply, even with Azure OpenAI and a pretrained Chroma instance. Write a more advanced question and see if it produces a query.
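One more thing worth checking (a guess on my part, not verified against your run): in the snippet above, `chat.send_message("{prompt}", **params)` passes the literal text `{prompt}` rather than the contents of the `prompt` variable, because the string is not an f-string. The model would then see no question at all, which could explain the generic greeting. A quick sketch of the difference (using a plain string for simplicity; Vanna actually passes a list of message dicts):

```python
prompt = "What are the top 5 properties by sales in 2023?"

literal = "{prompt}"        # plain string: send_message would receive the text '{prompt}'
interpolated = f"{prompt}"  # f-string: send_message would receive the actual question

print(literal)       # {prompt}
print(interpolated)  # What are the top 5 properties by sales in 2023?
```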
@andreped Alright I made the change as you mentioned, here is the entire code and response for reference:
```
[{'role': 'system', 'content': 'The user provides a question and you provide SQL...;} ... {'role': 'user', 'content': 'What are the top 5 properties by sales in 2023?'}]
What are the average sales per store in each state?
What are the average sales per store in each state?
Couldn't run sql: Execution failed on sql 'What are the average sales per store in each state?': near "What": syntax error
```
```python
class Palm(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # default values for params
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or the
          original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace("```", "")
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")
        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40
        chat_model = ChatModel.from_pretrained("chat-bison@001")
        chat = chat_model.start_chat(
            context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
        )
        params = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }
        response = chat.send_message("{prompt}", **params)
        return response.text


class MyVanna(VannaDB_VectorStore, Palm):
    def __init__(self, config=None):
        VannaDB_VectorStore.__init__(self, vanna_model="***", vanna_api_key="***", config=config)
        Palm.__init__(self, client=ChatModel("chat-bison@001"), config=config)
        print(self)


vn = MyVanna()
vn.connect_to_sqlite('database/rl_database.sqlite')
vn.ask("What are the top 5 properties by sales in 2023?")
```
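For what it's worth, a possible direction for a fix (an untested sketch on my part, since I cannot run Vertex AI): Vanna hands `submit_prompt` a list of role/content dicts, while `chat.send_message` expects a plain string, and `"{prompt}"` is not an f-string, so the literal text `{prompt}` is sent. Something like the following would flatten the messages and route the system message into `context` (`flatten_messages` is a hypothetical helper, not part of Vanna or the Vertex AI SDK):

```python
def flatten_messages(messages):
    # Join Vanna's [{'role': ..., 'content': ...}, ...] prompt into a single string
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

def submit_prompt(self, prompt, **kwargs) -> str:
    # Imported lazily here so the snippet loads without google-cloud-aiplatform installed
    from vertexai.language_models import ChatModel

    # Use the system message(s) as the chat context instead of dropping them
    context = "\n".join(m["content"] for m in prompt if m["role"] == "system")
    chat_model = ChatModel.from_pretrained("chat-bison@001")
    chat = chat_model.start_chat(context=context)

    # Send the remaining messages as actual text (not the literal string "{prompt}")
    text = flatten_messages([m for m in prompt if m["role"] != "system"])
    response = chat.send_message(text, temperature=0.7, max_output_tokens=500, top_p=0.95, top_k=40)
    return response.text
```

The two key differences from the snippet above are passing the flattened prompt text to `send_message` and supplying the system message via `context`.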
Hi guys. Sorry, I can't lend the API key. I have a corporate account.
> Hi guys. Sorry, I can't lend the API key. I have a corporate account.
No problem. I will check around if I can get a new trial. Maybe I just need to set up a new account :P
Hi. I want to try Vanna AI on PaLM API models (bison). Do you have any tutorial or documentation on how to set up those models in Vanna? It is not clear to me how to implement any other model if you choose "Other LLM" in the configuration options. Here is the code I can't figure out how to adapt to the PaLM API models: