Chainlit / chainlit

Build Conversational AI in minutes ⚡️
https://docs.chainlit.io
Apache License 2.0

`SQLAlchemyDataLayer` is missing a `get_element` method #1205

Open Simon-Stone opened 1 month ago

Simon-Stone commented 1 month ago

The SQLAlchemyDataLayer is missing a get_element method. Interacting with elements is therefore broken when this data layer is used.

I was trying to fix this, mirroring the implementation in ChainlitDataLayer. It looks like the method needs to return an ElementDict based on the thread_id and element_id. This maps to the columns threadId and id in the elements SQL table.

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        query = 'SELECT * FROM elements WHERE "id" = :element_id AND "threadId" = :thread_id'
        parameters = {"element_id": element_id, "thread_id": thread_id}
        records = await self.execute_sql(query=query, parameters=parameters)
        if not records:
            return None
        element = records[0]
        return ElementDict(**element)

I figured that part out, but the elements still don't show up, and I am not getting any errors. I am a bit lost here: async code is new to me, and I have a hard time tracing the execution after the ElementDict is returned.

How does this ElementDict interact with the storage? How and where is the actual Element created? Maybe this has something to do with the LocalStorageClient I am using:

from chainlit.data import BaseStorageClient

import mimetypes
from pathlib import Path
from typing import Any, Coroutine, Union

async def write_file(
    root: str | Path, object_key: str | Path, content: str | bytes, mime: str
):
    # guess_extension() returns None for unknown MIME types, so fall back to no extension
    path = Path(root) / (str(object_key) + (mimetypes.guess_extension(mime) or ""))
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    if isinstance(content, bytes):
        with open(path, "wb") as f:
            f.write(content)
    else:
        with open(path, "w") as f:
            f.write(content)

    return {"object_key": object_key, "path": str(path.absolute())}

class LocalStorageClient(BaseStorageClient):

    def __init__(self, dir: Union[str, Path]) -> None:
        super().__init__()
        self.dir = Path(dir)
        self.dir.mkdir(parents=True, exist_ok=True)

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
    ) -> dict[str, Any]:

        return await write_file(
            root=self.dir, object_key=object_key, content=data, mime=mime
        )

Any help would be appreciated!

AidanShipperley commented 1 month ago

Hi @Simon-Stone,

I encountered this exact same issue a month ago, and it's unfortunately really complicated. While I do have a solution, I have no idea if this is the correct or optimal one.

As you suggested, the issue you're describing has everything to do with the LocalStorageClient you're using for element storage. I wanted to do exactly the same thing as you, because I am implementing this for my company and we didn't want to use any cloud storage.

Let me explain the conflict that you are encountering as simply as possible:

Why can Chainlit only display elements through a URL? Element has a path parameter, so why can't we just store our elements there, as you are attempting to do? The issue is that neither the ElementDict class nor the SQL table schema stores the path value of an Element at all. That means the path in your return {"object_key": object_key, "path": str(path.absolute())} is never used. Take a look at the create_element() function in sql_alchemy.py, which calls your upload_file() function: it only inserts url and object_key into your SQL table.
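To make the point concrete, here is a minimal sketch that mirrors the two lookups made in the create_element() code posted later in this thread (uploaded_file.get("url") and uploaded_file.get("object_key")). The helper name extract_persisted_fields and the sample values are made up for illustration; this is not Chainlit's own code.

```python
from typing import Any, Dict, Optional


def extract_persisted_fields(uploaded_file: Dict[str, Any]) -> Dict[str, Optional[str]]:
    """Hypothetical helper: keep only the keys the data layer reads back."""
    return {
        "url": uploaded_file.get("url"),
        "objectKey": uploaded_file.get("object_key"),
    }


# A storage client that returns a local path instead of a URL (sample values).
local_result = {
    "object_key": "user-1/el-1/cat.png",
    "path": "/srv/files/user-1/el-1/cat.png",
}
persisted = extract_persisted_fields(local_result)
print(persisted)  # the "path" entry is dropped and "url" is None
```

With a result shaped like this, nothing the frontend could load ever reaches the database, which would explain elements silently not showing up.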

I would say just add path as a column in your SQL table schema and get it to work, but you would need to change every instance of Chainlit using URL from element throughout the entire project, and at that point you would effectively be forking it for your use case.

Thus, your elements need to be accessible via a URL that corresponds to the file path where they reside. In my case, I put all of my storage elements into a folder under the public folder, which is mounted to the app with FastAPI, so its contents are accessible from the client via a URL.

Then you adjust your create_element() and upload_file() functions to save the elements and store in SQL the URL that accesses each element instead of the path.
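The path-to-URL mapping could be sketched roughly like this. The function name file_url, the base URL http://localhost:8000 and the sample keys are assumptions for illustration only; the real base URL depends on how the app is served.

```python
from urllib.parse import quote


def file_url(base_url: str, url_path: str, object_key: str) -> str:
    """Build the URL a client would request for a stored file.

    base_url   -- where the app is served (assumed here)
    url_path   -- URL prefix of the served folder, e.g. "public/storage"
    object_key -- key relative to the storage folder; may contain "\\" on Windows
    """
    # Normalise Windows separators, then percent-encode unsafe characters.
    # quote() leaves "/" untouched by default, so the folder structure survives.
    key = quote(object_key.replace("\\", "/").lstrip("/"))
    return "/".join([base_url.rstrip("/"), url_path.strip("/"), key])


print(file_url("http://localhost:8000", "public/storage", "user-1\\el-1\\my file.png"))
```

The value stored in the url column would then be this string rather than the absolute path on disk.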

Here is the code I wrote to do that:

@queue_until_user_message()
async def create_element(self, element: "Element"):
    logger.debug(f"DataLayer: create_element(element_id = '{element.id}')")
    if not getattr(context.session.user, 'id', None):
        raise ValueError("No authenticated user in context")
    if not self.storage_provider:
        logger.warning("DataLayer: create_element error. No storage_client is configured!")
        return
    if not element.for_id:
        return

    content: Optional[Union[bytes, str]] = None

    if element.path:
        async with aiofiles.open(element.path, "rb") as f:
            content = await f.read()
    elif element.url:
        async with aiohttp.ClientSession() as session:
            async with session.get(element.url) as response:
                if response.status == 200:
                    content = await response.read()
                else:
                    content = None
    elif element.content:
        content = element.content
    else:
        raise ValueError("Element url, path or content must be provided")

    if content is None:
        raise ValueError("Content is None, cannot upload file")

    context_user = context.session.user

    user_folder = getattr(context_user, "id", "unknown")
    file_object_key = os.path.join(user_folder, element.id)
    if element.name:
        file_object_key = os.path.join(file_object_key, element.name)

    if not element.mime:
        element.mime = "application/octet-stream"

    uploaded_file = await self.storage_provider.upload_file(
        object_key=file_object_key, data=content, mime=element.mime, overwrite=True
    )
    if not uploaded_file:
        raise ValueError(
            "DataLayer Error: create_element, Failed to persist data in storage_provider"
        )

    element_dict: ElementDict = element.to_dict()

    element_dict["url"] = uploaded_file.get("url")
    element_dict["objectKey"] = uploaded_file.get("object_key")
    element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}

    # ... Upsert into SQL

And my custom storage client:

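import mimetypes
import os
from typing import Any, Dict, Union

import aiofiles
from chainlit.config import config, DEFAULT_HOST
from chainlit.data import BaseStorageClient
from chainlit.logger import logger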

class FSStorageClient(BaseStorageClient):
    """
    Class to enable File System storage for ChainLit elements.
    """
    def __init__(self, storage_path: str, url_path: str):
        self.storage_path = storage_path
        self.url_path = url_path
        if not os.path.exists(self.storage_path):
            os.makedirs(self.storage_path, exist_ok=True)

        # Get serving URL
        host = config.run.host
        port = config.run.port

        if host == DEFAULT_HOST:
            self.url = f"http://localhost:{port}{os.environ.get('CHAINLIT_ROOT_PATH', '')}"
        else:
            self.url = f"http://{host}:{port}{os.environ.get('CHAINLIT_ROOT_PATH', '')}"

    async def upload_file(self, object_key: str, data: Union[bytes, str],
                          mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:

        try:
            # Clean file key and attempt to steal extension
            object_key, s_existing_extension = os.path.splitext(object_key)

            if s_existing_extension == "":
                # Guess extension if there is none; guess_extension() may return None
                s_file_extension = mimetypes.guess_extension(mime) or ""
            else:
                s_file_extension = s_existing_extension
            s_object_key_final = object_key + s_file_extension
            s_object_key_url = s_object_key_final.replace("\\", "/")

            s_file_path = os.path.join(self.storage_path, s_object_key_final)

            # Ensure directory exists, Python does not create them automatically
            os.makedirs(os.path.dirname(s_file_path), exist_ok=True)

            # If we should not overwrite, fail if file exists
            if not overwrite and os.path.exists(s_file_path):
                return {}

            logger.debug(f"FSStorageClient, uploading file to: '{s_file_path}'")

            # Open the file in binary write mode
            async with aiofiles.open(s_file_path, "wb") as f:
                # Check if data is of type str, if yes, convert to bytes
                if isinstance(data, str):
                    data = data.encode('utf-8')
                await f.write(data)

            # Calculate URL for this file
            s_file_url = f"{self.url}/{self.url_path}/{s_object_key_url}"

            logger.debug(f"FSStorageClient, saving access URL as: '{s_file_url}'")

            return {"object_key": s_object_key_final, "url": s_file_url}

        except Exception as e:
            logger.warning(f"FSStorageClient, upload_file error: {e}")
            return {}

Then, you define your storage client as follows in your app.py:

fs_storage_client = FSStorageClient(
    storage_path=os.path.join(os.getcwd(), "public", "storage"),
    url_path="public/storage"
)

Hopefully this helps you get it working!

Simon-Stone commented 1 month ago

Wow, thank you so much for the thorough explanation and the detailed example code! So much appreciated! I will give this a try ASAP.

Simon-Stone commented 1 month ago

Still parsing what you shared here. Did you also create a get_element method for the data layer? Or am I misunderstanding how retrieving an element should work?

AidanShipperley commented 1 month ago

@Simon-Stone Oh so technically yes I did, but I am pretty sure it's never used. I have mine here for reference though:

async def get_element(self, thread_id: str, element_id: str) -> Optional[ElementDict]:
    logger.debug(f"DataLayer: get_element, thread_id='{thread_id}', element_id='{element_id}'")
    s_query = "SELECT TOP 1 * FROM dbo.elements WHERE id = :id"
    d_parameters = {"id": element_id}
    ld_result = await self.execute_sql(query=s_query, parameters=d_parameters)
    logger.debug(f"DataLayer: get_element, element result: {ld_result}")
    if ld_result and isinstance(ld_result, list):
        d_element_data = ld_result[0]
        return ElementDict(**d_element_data)
    return None

I just ran through some tests to confirm, and from what I can tell it never gets called. I think they don't need it because, for the current chat window, they use a temporary folder under .files to store images and display them in the current chat, so there is no need to read element data back from the SQL database. Images already display in the chat when you have no data layer at all, so they had to support that case anyway. When you resume an old chat and need the element data, they always call the get_all_user_threads() method, which already fetches all of the elements in all of the user's threads and adds those elements to each thread's data.
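As a rough illustration of that last step, attaching elements to their threads could look like the sketch below. The function name attach_elements and the sample records are hypothetical, and this is not the actual get_all_user_threads() implementation; only the id and threadId field names come from the elements table discussed above.

```python
from collections import defaultdict
from typing import Any, Dict, List


def attach_elements(
    threads: List[Dict[str, Any]], elements: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Return copies of the threads, each with its own list of elements."""
    by_thread: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for element in elements:
        by_thread[element["threadId"]].append(element)
    # .get() avoids creating empty entries for threads without elements.
    return [{**thread, "elements": by_thread.get(thread["id"], [])} for thread in threads]


threads = [{"id": "t1"}, {"id": "t2"}]
elements = [
    {"id": "e1", "threadId": "t1", "url": "http://localhost:8000/public/storage/a.png"},
    {"id": "e2", "threadId": "t1", "url": "http://localhost:8000/public/storage/b.png"},
]
resumed = attach_elements(threads, elements)
print([len(t["elements"]) for t in resumed])
```

If elements arrive this way when a thread is resumed, a per-element lookup would not be needed on that code path.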

However, if you aren't seeing any elements show up whatsoever in the current chat window, you may just have an issue with your implementation. If you shared some of your app.py code I could skim over it.

Simon-Stone commented 1 month ago

Thank you so much, once again! It turns out that I did need the get_element() method in the SQLAlchemyDataLayer class.

So in summary:

The former is on the user, since Chainlit does not support a local storage client (why not?). The latter remains an issue, though!

I really appreciate your thorough help and support here, @AidanShipperley !!

AidanShipperley commented 3 weeks ago

@Simon-Stone Nice, I'm glad you figured it out! You provided a very succinct explanation, I appreciate that.

I imagine not supporting local storage is just one of those things you can't predict in early development; you're trying to think about every possible thing users may want and you have to make some assumptions or you'll develop forever and never release your project (I'm a victim of this). I can see users saying the exact same thing if Chainlit only supported local file storage. Chainlit is fantastic in terms of feature completeness, it's just going to take a long time to polish out all of the loose ends.

I am really surprised that implementing get_element() ended up working for you. I ran through everything element-related in my chat app (having the AI send an image/audio, sending an image/audio myself, and returning to an old chat with elements and resuming it), but nothing triggered my get_element() function. Did you notice which actions trigger the function in your app? I'd love to see whether I just missed something.
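One way to find out what triggers it would be to log the call stack whenever the method runs. The sketch below is a generic pattern, not something from Chainlit: the log_caller decorator, the logger name and the DemoDataLayer stand-in class are all made up for illustration.

```python
import asyncio
import functools
import logging
import traceback

logger = logging.getLogger("element_trace")  # hypothetical logger name


def log_caller(func):
    """Log the call stack every time the wrapped coroutine function runs."""

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # Drop the last frame, which is this wrapper itself.
        stack = "".join(traceback.format_stack(limit=8)[:-1])
        logger.debug("%s called from:\n%s", func.__name__, stack)
        return await func(*args, **kwargs)

    return wrapper


class DemoDataLayer:
    """Stand-in class for illustration; not Chainlit's data layer."""

    @log_caller
    async def get_element(self, thread_id: str, element_id: str):
        return {"id": element_id, "threadId": thread_id}


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    print(asyncio.run(DemoDataLayer().get_element("t1", "e1")))
```

Applied to a real get_element(), the logged frames should point at whichever request handler awaited the method, although in async code the trace can end at the event loop if the call was scheduled as a separate task.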

Thank you!