tsterbak / promptmage

simplifies the process of creating and managing LLM workflows.
MIT License
68 stars, 4 forks

[BUG] template_vars not being stored correctly in the SQL Database #5

Closed: LuciAkirami closed this issue 3 weeks ago

LuciAkirami commented 3 weeks ago

These are the results of querying the database before and after editing a prompt in the UI:

Before editing:
id | name | system | user | version | template_vars
af238ea6-7ac1-4eab-9073-fad88c52342f | qa_bot | You are a helpful AI assistant | Answer the following user question:{query} | 1 | query

After editing:
id | name | system | user | version | template_vars
af238ea6-7ac1-4eab-9073-fad88c52342f | qa_bot | You are a helpful AI assistant | Answer the following human question:{query} | 2 | q,u,e,r,y

I found where the issue is. It's in the update_prompt function in sqlite_backend.py, line 115:

latest_prompt.template_vars = ",".join(prompt.template_vars)

Here prompt.template_vars is already a string, so calling join on it iterates over its characters and joins them with commas, which produces the behaviour above. This is easily fixable.
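The effect can be reproduced in isolation. This is a minimal sketch that does not depend on promptmage; the variable names are only for illustration. str.join accepts any iterable of strings, and a plain string is an iterable of its own characters.

```python
# A list with one variable name: what update_prompt is meant to receive.
as_list = ["query"]
# A plain string: what it receives when the value comes straight from the database row.
as_str = "query"

joined_from_list = ",".join(as_list)
joined_from_str = ",".join(as_str)

print(joined_from_list)  # query
print(joined_from_str)   # q,u,e,r,y
```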

Can I create a pull request on this?

LuciAkirami commented 3 weeks ago

Actually, it doesn't end here. I tried to trace it from the beginning. When I click Save in the UI, the following code gets triggered:

frontend/components/prompts_page.py - Line 89

def save_prompt(prompt_id: str, system: str, user: str):
    prompt = mage.prompt_store.get_prompt_by_id(prompt_id)
    prompt.system = system
    prompt.user = user

Here, the first line calls prompt_store.get_prompt_by_id(). If we look into that:

storage/prompt_store.py - Line 40

    def get_prompt_by_id(self, prompt_id: str) -> Prompt:
        logger.info(f"Retrieving prompt with ID {prompt_id}")
        return self.backend.get_prompt_by_id(prompt_id)

This calls the get_prompt_by_id() method of the backend class. So if we take a look at that:

storage/sqlite_backend.py - Line 143

    def get_prompt_by_id(self, prompt_id: str) -> Prompt:
        session = self.Session()
        try:
            prompt = session.execute(
                select(PromptModel).where(PromptModel.id == prompt_id)
            ).scalar_one_or_none()
            if prompt is None:
                raise PromptNotFoundException(f"Prompt with ID {prompt_id} not found.")
            # Returns the PromptModel row as-is, despite the Prompt annotation
            return prompt
        finally:
            session.close()

Here is the issue. If we take a close look at the function, it is annotated to return an object of type Prompt, but it actually returns an object of type PromptModel. The PromptModel object stores template_vars as a str, whereas the Prompt object stores it as a list.
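The mismatch goes unnoticed because Python does not enforce return annotations at runtime. A minimal sketch with stand-in classes follows; the two dataclasses are simplified assumptions for illustration, not promptmage's actual definitions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Prompt:
    """Stand-in for the domain object: template_vars is a list."""
    name: str
    template_vars: List[str]


@dataclass
class PromptModel:
    """Stand-in for the database row: template_vars is a comma-separated str."""
    name: str
    template_vars: str


def get_prompt_by_id(prompt_id: str) -> Prompt:
    # The annotation promises a Prompt, but nothing checks it at runtime,
    # so returning the row object raises no error here.
    return PromptModel(name="qa_bot", template_vars="query")


result = get_prompt_by_id("some-id")
print(type(result).__name__)                  # PromptModel
print(isinstance(result.template_vars, str))  # True
```

In this sketch, a static type checker would be expected to report the incompatible return value, since the two stand-in classes are unrelated.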

So, let's say it returns a PromptModel object. Then, moving back to save_prompt() in components/prompts_page.py:

frontend/components/prompts_page.py - Line 89

def save_prompt(prompt_id: str, system: str, user: str):
    prompt = mage.prompt_store.get_prompt_by_id(prompt_id)
    prompt.system = system
    prompt.user = user
    mage.prompt_store.update_prompt(prompt)  # <- the PromptModel is passed on here

So mage.prompt_store.get_prompt_by_id(prompt_id) returns a PromptModel object, and we pass this PromptModel object to mage.prompt_store.update_prompt(prompt). Now taking a look at this function in sqlite_backend.py:

storage/sqlite_backend.py - Line 94

    def update_prompt(self, prompt: Prompt):
        session = self.Session()
        try:
            existing_prompt = (
                session.execute(
                    select(PromptModel).where(PromptModel.name == prompt.name)
                )
                .scalars()
                .all()
            )

            if not existing_prompt:
                raise PromptNotFoundException(
                    f"Prompt with name {prompt.name} not found."
                )

            latest_prompt = max(existing_prompt, key=lambda p: p.version)

            latest_prompt.version += 1
            latest_prompt.system = prompt.system
            latest_prompt.user = prompt.user
            latest_prompt.template_vars = ",".join(prompt.template_vars)  # <- the bug
            session.commit()
        except SQLAlchemyError as e:
            logger.error(f"Error updating prompt: {e}")
        finally:
            session.close()

The update_prompt() method expects an object of type Prompt, but we are passing it an object of type PromptModel. The issue shows up in the line latest_prompt.template_vars = ",".join(prompt.template_vars). If it had been a Prompt object, this code would work fine and all the template_vars in the list would be joined into a single str. But since we are passing a PromptModel object, where template_vars is already a str, join iterates over the individual characters of that str and joins them with a comma (,). In the case above, query became q,u,e,r,y.

So the potential fix is in the get_prompt_by_id() method of the SQLiteBackend class in sqlite_backend.py.

Fix: storage/sqlite_backend.py - Line 143

    def get_prompt_by_id(self, prompt_id: str) -> Prompt:
        session = self.Session()
        try:
            prompt_model = session.execute(
                select(PromptModel).where(PromptModel.id == prompt_id)
            ).scalar_one_or_none()
            if prompt_model is None:
                raise PromptNotFoundException(f"Prompt with ID {prompt_id} not found.")
            # FIX - Converting PromptModel to Prompt
            prompt = Prompt(
                # ... remaining fields copied from prompt_model ...
                template_vars=prompt_model.template_vars.split(","),
            )
            # END FIX
            return prompt
        finally:
            session.close()

So here, we retrieve a PromptModel object from the SQL database and convert it to a Prompt object. We take the str value in prompt_model.template_vars and convert it to a list by splitting it at the comma (,).
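The round trip between the list form and the stored string can be sketched as follows. The helper names vars_to_column and column_to_vars are hypothetical and not part of promptmage. The empty-column guard is an addition beyond the fix proposed above: "".split(",") returns [""], so a prompt with no variables would otherwise come back with one empty variable name.

```python
from typing import List


def vars_to_column(template_vars: List[str]) -> str:
    """Serialize a list of variable names for the text column."""
    return ",".join(template_vars)


def column_to_vars(column_value: str) -> List[str]:
    """Parse the text column back into a list of variable names."""
    # "".split(",") returns [""], so handle the empty column explicitly.
    if not column_value:
        return []
    return column_value.split(",")


stored = vars_to_column(["query", "context"])
restored = column_to_vars(stored)
print(stored)              # query,context
print(restored)            # ['query', 'context']
print(column_to_vars(""))  # []
```

Because the conversion is symmetric, serializing the restored list again yields the same stored string, which is the property the original update_prompt line was missing when it received a str.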

If you allow, I can create a pull request for this change :smile:

tsterbak commented 3 weeks ago

Great findings! :)

Please go ahead and create a pull request! That would be very welcome! :partying_face:

LuciAkirami commented 3 weeks ago

Sure, thank you. I will create it.

LuciAkirami commented 3 weeks ago

Closing this issue as it's resolved.