vanna-ai / vanna

🤖 Chat with your SQL database 📊. Accurate Text-to-SQL Generation via LLMs using RAG 🔄.
https://vanna.ai/docs/
MIT License
9.87k stars, 724 forks

Google PaLM API models #239

Closed: hugoferrero closed this issue 3 months ago

hugoferrero commented 5 months ago

Hi. I want to try Vanna AI with the PaLM API models (bison). Do you have any tutorial or documentation on how to set up those models in Vanna? It is not clear to me how to implement another model if you choose "Other LLM" in the configuration options. Here is the code I can't figure out how to adapt to the PaLM API models:

from vanna.base import VannaBase
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyCustomLLM(VannaBase):
    def __init__(self, config=None):
        pass

    def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
        ...  # Implement here

    def generate_question(self, sql: str, **kwargs) -> str:
        ...  # Implement here

    def get_followup_questions_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
        ...  # Implement here

    def get_sql_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
        ...  # Implement here

    def submit_prompt(self, prompt, **kwargs) -> str:
        ...  # Implement here

class MyVanna(ChromaDB_VectorStore, MyCustomLLM):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        MyCustomLLM.__init__(self, config=config)

vn = MyVanna()
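In implementations like the Ollama one further down this thread, the prompt that reaches submit_prompt is a list of role/content dicts built with system_message, user_message and assistant_message. A chat API can take that list more or less directly. A text-only completion model such as text-bison takes a single string instead, so one option is to flatten the list. The following is a minimal sketch. The helper name messages_to_text and the "Question:"/"SQL:" labels are hypothetical choices, not part of Vanna.

```python
from typing import Dict, List

# Role/content dicts, the same shape the Ollama implementation returns
Message = Dict[str, str]

def system_message(message: str) -> Message:
    return {"role": "system", "content": message}

def user_message(message: str) -> Message:
    return {"role": "user", "content": message}

def assistant_message(message: str) -> Message:
    return {"role": "assistant", "content": message}

def messages_to_text(messages: List[Message]) -> str:
    """Flatten a chat-style message list into one prompt string (hypothetical helper)."""
    labels = {"system": "", "user": "Question: ", "assistant": "SQL: "}
    lines = [labels[m["role"]] + m["content"] for m in messages]
    # End with an open "SQL:" slot so a completion model continues with the query
    lines.append("SQL:")
    return "\n\n".join(lines)

# Example: one question/SQL pair as a few-shot example, then the new question
prompt = [
    system_message("You are a SQLite expert. Answer with SQL only."),
    user_message("How many customers are there?"),
    assistant_message("SELECT COUNT(*) FROM customers;"),
    user_message("How many orders are there?"),
]
print(messages_to_text(prompt))
```

A chat model (chat-bison) would not need this step. There, the system content can go into the chat context and the question/SQL pairs into the chat history.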
andreped commented 4 months ago

For reference, here is the notebook that I believe the code example above is derived from: https://github.com/vanna-ai/vanna/blob/fb384d43a0fb50a7cbe366cd6c242d5d37b16569/notebooks/bigquery-other-llm-chromadb.ipynb

It would be great to have a notebook describing how to do this, or something similar :]

andreped commented 4 months ago

Perhaps you want to take a look at the recently added Ollama implementation: https://github.com/vanna-ai/vanna/blob/fb384d43a0fb50a7cbe366cd6c242d5d37b16569/src/vanna/ollama/__init__.py#L6

I have pasted the code below; it should give you an idea of what is required to add support for another LLM.

Perhaps that is exactly what you were looking for, @hugoferrero? :]

from vanna.base import VannaBase  # inside the vanna package this is "from ..base import VannaBase"
import re

import requests

class Ollama(VannaBase):
    def __init__(self, config=None):
        # Fall back to a local Ollama server if no host is configured
        if config is None or 'ollama_host' not in config:
            self.host = "http://localhost:11434"
        else:
            self.host = config['ollama_host']

        if config is None or 'model' not in config:
            raise ValueError("config must contain an Ollama model")
        else:
            self.model = config['model']

    # Vanna assembles its prompt from these role/content dicts,
    # which is also the message format Ollama's /api/chat endpoint expects
    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement starting at the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or the original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)

        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace('```', '')
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        # Drop any remaining backslashes
        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        url = f"{self.host}/api/chat"
        data = {
            "model": self.model,
            "stream": False,  # ask for a single JSON response instead of a stream
            "messages": prompt,
        }

        response = requests.post(url, json=data)
        response.raise_for_status()  # surface HTTP errors instead of a KeyError below

        response_dict = response.json()

        self.log(response.text)

        return response_dict['message']['content']

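The post-processing in generate_sql and extract_sql_query does not depend on the LLM, so it should carry over unchanged to a PaLM class. The following is a standalone copy of that logic. The function names mirror the methods above, and clean_sql is a hypothetical name for the generate_sql cleanup steps. It shows what a typical model reply turns into.

```python
import re

def extract_sql_query(text: str) -> str:
    """Return the first 'select ...' statement, up to ';', '```' or the end of the text."""
    pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
    match = pattern.search(text)
    # Strip a captured code fence; fall back to the raw text when nothing matches
    return match.group(0).replace("```", "") if match else text

def clean_sql(raw: str) -> str:
    """Mirror the cleanup done in Ollama.generate_sql before extraction."""
    raw = raw.replace("\\_", "_")  # unescape markdown-escaped underscores
    raw = raw.replace("\\", "")    # drop any remaining backslashes
    return extract_sql_query(raw)

reply = "Here is the query:\n```sql\nSELECT order\\_id FROM orders;\n```"
print(clean_sql(reply))
```

One caveat: the pattern also fires on prose words such as "selected". A reply with no "select" in it, like "Hi there, how can I help you today?", comes back unchanged and would then be executed as SQL.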
hugoferrero commented 4 months ago

Thanks for the response, @andreped. I will try it and send you feedback.

zainhoda commented 4 months ago

@hugoferrero if you happen to make progress on this, could you pass along your code and we can potentially integrate this into the main Vanna repo?

andreped commented 4 months ago


@zainhoda I am open to drafting a PR for this :] I can tag you, @hugoferrero, if you wish to test it before merging.

andreped commented 4 months ago

I made a PR https://github.com/vanna-ai/vanna/pull/264.

It is a rather simple implementation, but sadly I do not have access to Google Cloud, so I depend on some of you to test it.

To install:

pip install git+https://github.com/andreped/vanna.git@bison-support
pip install chromadb google-cloud-aiplatform

Then you should be able to initialize it together with a vector DB such as Chroma, like so:

from vertexai.language_models import ChatModel
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.palm.palm import Palm

class MyVanna(ChromaDB_VectorStore, Palm):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        # Vertex AI language models are loaded with from_pretrained, not the bare constructor
        Palm.__init__(self, client=ChatModel.from_pretrained("chat-bison@001"), config=config)

vn = MyVanna()

# do as you normally would with any other client(s)
# [...]
yedhukr commented 4 months ago

@andreped I've tried the code, but it's not working; it gives errors like: 'MyVanna' object has no attribute 'temperature'

I then tried setting the parameters explicitly, as below:

def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40

    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }

    response = self.client.send_message(prompt, **params)
    return response.text

This gave me the error: 'ChatModel' object has no attribute 'send_message'

I also attempted to specify the model more explicitly:

def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40

    client = ChatModel("chat-bison@001")
    chat = client.start_chat()
    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    print(prompt)
    response = chat.send_message(prompt, **params)
    return response.text

This gave me the error: 400 Invalid resource field value in the request.

Please help 🙏🏼
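In the Palm class posted later in this thread, temperature, max_tokens, top_p and top_k are plain local variables in __init__. The constructor also returns before reaching them whenever a client is passed, so they never become attributes. That would be consistent with the "no attribute 'temperature'" error. The second error fits the later snippets, where send_message is called on the chat session returned by start_chat() rather than on ChatModel itself. The following is a minimal sketch of keeping the defaults on the instance. The class name and the config keys are hypothetical, and the parameter names are assumed to match what the chat session's send_message accepts.

```python
from typing import Any, Dict, Optional

class PalmGenerationDefaults:
    """Hypothetical mixin: keep sampling settings on the instance so that
    submit_prompt can read self.temperature etc. whichever constructor path ran."""

    DEFAULTS = {"temperature": 0.7, "max_output_tokens": 500, "top_p": 0.95, "top_k": 40}

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        config = config or {}
        # Set the attributes before any early return; config values override defaults
        for name, default in self.DEFAULTS.items():
            setattr(self, name, config.get(name, default))

    def generation_params(self) -> Dict[str, Any]:
        # Keyword arguments intended for chat.send_message(message, **params)
        return {name: getattr(self, name) for name in self.DEFAULTS}

params = PalmGenerationDefaults({"temperature": 0.2}).generation_params()
print(params)
```

In a Palm class, this initialization would run at the top of __init__, before the `if client is not None: return` branch.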

andreped commented 4 months ago


Hello, @yedhukr! :] Great that you were able to test the implementation!

I don't have access to Google Cloud, so I have no way of testing it. Perhaps someone could lend me an API key so that I can debug this properly? It would only be for this PR, and the key could be rotated afterwards. @zainhoda?

yedhukr commented 4 months ago

@andreped Let me know if there's anything else I can do to help!

Modifying the function in this manner gets it to run, but I get a response like: 'Hi there, how can I help you today?'

    def submit_prompt(self, prompt, **kwargs) -> str:
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

        chat_model = ChatModel.from_pretrained("chat-bison@001")
        chat = chat_model.start_chat()
        params = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }

        response = chat.send_message("{prompt}", **params)
        return response.text
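One detail in this snippet may explain the generic greeting. "{prompt}" has no f prefix, so the model is sent the literal text {prompt} and never sees the question. Even with an f prefix, the model would receive the Python repr of the whole message list. The following small illustration uses a shortened, hypothetical stand-in for the list Vanna passes to submit_prompt.

```python
# Shortened stand-in for the message list Vanna passes to submit_prompt
prompt = [
    {"role": "system", "content": "The user provides a question and you provide SQL."},
    {"role": "user", "content": "What are the top 5 properties by sales in 2023?"},
]

literal = "{prompt}"    # no f prefix: the model receives these eight characters
as_repr = f"{prompt}"   # f-string: the model receives the repr of the whole list

# Closer to what a chat API wants: the last user message as the message,
# and the system content passed separately (e.g. as the chat context)
question = next(m["content"] for m in reversed(prompt) if m["role"] == "user")
context = "\n".join(m["content"] for m in prompt if m["role"] == "system")

print(literal)
print(question)
```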
andreped commented 4 months ago


Which user prompt did you use? It also sounds like you are missing the system message that Vanna uses.

I think that by calling chat = chat_model.start_chat() without any arguments, you are not setting the system message. Here you can see an example that includes a system message: https://cloud.google.com/vertex-ai/generative-ai/docs/sdk-for-llm/sdk-use-text-models#generate-text-chat-sdk

As a test, could you try passing in the system message that the VannaBase class uses here: https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py#L448

Basically, change chat = chat_model.start_chat() to:

chat = chat_model.start_chat(
    context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
)

If that works, I think I know how to fix the issue.


EDIT: It could also be that you just wrote "Hello" as the user prompt; in that case I have gotten the same reply myself, even with Azure OpenAI and a pretrained Chroma instance. Try a more specific question and see whether it produces a query.

yedhukr commented 4 months ago

@andreped Alright, I made the change you mentioned. Here is the entire code and the response for reference:

[{'role': 'system', 'content': 'The user provides a question and you provide SQL...;} ... {'role': 'user', 'content': 'What are the top 5 properties by sales in 2023?'}]
What are the average sales per store in each state?
What are the average sales per store in each state?
Couldn't run sql: Execution failed on sql 'What are the average sales per store in each state?': near "What": syntax error

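import re

from vertexai.language_models import ChatModel
from vanna.base import VannaBase
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore

class Palm(VannaBase):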
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # default values for params
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or the original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace("```", "")
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

        chat_model = ChatModel.from_pretrained("chat-bison@001")
        chat = chat_model.start_chat(
            context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
        )
        params = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }

        response = chat.send_message("{prompt}", **params)
        return response.text

class MyVanna(VannaDB_VectorStore, Palm):
    def __init__(self, config=None):
        VannaDB_VectorStore.__init__(self, vanna_model="***", vanna_api_key="***", config=config)
        Palm.__init__(self, client=ChatModel("chat-bison@001"), config=config)
        print(self)

vn = MyVanna()
vn.connect_to_sqlite('database/rl_database.sqlite')
vn.ask("What are the top 5 properties by sales in 2023?")
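Two things in this version would explain the output above. First, the context is hard-coded to the generate_question instruction (the one linked earlier as a test), so the model is told to turn SQL into a question. Second, "{prompt}" is still a literal string, so the model never receives Vanna's messages. The question it returns is then executed as SQL, which fails. The following untested sketch derives the context and history from Vanna's message list instead. It assumes the vertexai SDK's ChatModel.start_chat(context=..., message_history=...) and ChatMessage(content=..., author=...), with "user" and "bot" as the author names. Both function names are hypothetical. The chat model and the message constructor are injected so the logic can be exercised without Google Cloud.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Message = Dict[str, str]

def split_vanna_prompt(prompt: List[Message]) -> Tuple[str, List[Tuple[str, str]], str]:
    """Split a role/content message list into (context, history, final user message)."""
    context = "\n\n".join(m["content"] for m in prompt if m["role"] == "system")
    turns = [(m["role"], m["content"]) for m in prompt if m["role"] != "system"]
    if not turns or turns[-1][0] != "user":
        raise ValueError("expected the prompt to end with a user message")
    return context, turns[:-1], turns[-1][1]

def palm_submit_prompt(
    chat_model: Any,
    prompt: List[Message],
    make_message: Optional[Callable[..., Any]] = None,
    **params: Any,
) -> str:
    """Send a Vanna-style prompt to a PaLM chat model (untested sketch)."""
    context, history, question = split_vanna_prompt(prompt)
    if make_message is None:
        # Imported lazily so split_vanna_prompt works without the SDK installed
        from vertexai.language_models import ChatMessage
        make_message = ChatMessage
    # Assumption: PaLM chat history uses "user" and "bot" as author names
    authors = {"user": "user", "assistant": "bot"}
    chat = chat_model.start_chat(
        context=context,
        message_history=[make_message(content=c, author=authors[r]) for r, c in history],
    )
    # params: e.g. temperature, max_output_tokens, top_p, top_k
    return chat.send_message(question, **params).text
```

Inside the Palm class, submit_prompt could then return palm_submit_prompt(self.client, prompt, temperature=0.7, max_output_tokens=500), where self.client is assumed to be ChatModel.from_pretrained("chat-bison@001"). This also assumes Vertex AI credentials and project are already configured.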
hugoferrero commented 4 months ago

Hi guys. Sorry, I can't lend the API key. I have a corporate account.

andreped commented 4 months ago


No problem. I will look into whether I can get a new trial. Maybe I just need to set up a new account :P