open-webui / pipelines

Pipelines: Versatile, UI-Agnostic OpenAI-Compatible Plugin Framework
MIT License
301 stars 71 forks source link

Question: Can I use Pipelines to add a custom document parser? (#39)

Open · bryanhj opened this issue 4 weeks ago

bryanhj commented 4 weeks ago

My use case is that I'd like to perform RAG on some documents whose format is not supported by any of the loaders imported from `langchain_community.document_loaders`. I have a working loader based on llama_index's `Document`/`BaseReader` that I could refactor, but I don't see a way to add it without forking. The file format is internal/private, so it wouldn't make sense to add it to the main open-webui project.

I found examples showing how to implement an entire RAG pipeline, and that's what I initially planned to do. However, the built-in document handling is excellent, and it appears I would lose most of that functionality if I went this route, so I thought it was worth asking.

tjbck commented 4 weeks ago

I believe https://github.com/open-webui/pipelines/issues/9 would cover your use case!

MarlNox commented 4 weeks ago

I used the RAG API after copying `webui.db` and `chroma.sqlite3` into the pipelines folder. Set up a user API key and use it to authenticate with the API. The script below fetches the file names and their collections from the databases, then uses the query-collection endpoint of the Open WebUI API to run queries against them.

Example code for Jupyter:

import json
import os
import sqlite3
from contextlib import closing

import requests

# Paths to the local copies of the Open WebUI databases.
db_path_chroma = './data/vector_db/chroma.sqlite3'
db_path_webui = './data/webui.db'

# Base URL of the Open WebUI RAG API.
base_url = "http://localhost:8080/rag/api/v1"

# API key sent as a bearer token. Supply a real key through the
# OPENWEBUI_API_KEY environment variable instead of committing it to source;
# the literal below is only a placeholder fallback (also used when the
# variable is set but empty).
bearer_token = os.environ.get("OPENWEBUI_API_KEY") or "sk-1234abc1234"

def _fetch_all(db_path, query):
    """Run *query* against the SQLite database at *db_path* and return all rows.

    Raises:
        FileNotFoundError: if *db_path* is not an existing file. Without this
            check ``sqlite3.connect`` would silently create an empty database
            at the mistyped path, and the query would then fail with a
            misleading "no such table" error.
        sqlite3.Error: if the query itself fails.
    """
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")
    # A sqlite3.Connection used as a context manager only commits/rolls back;
    # closing() is what guarantees the connection is closed, even on error.
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(query).fetchall()


def extract_data_to_json(db_path_chroma, db_path_webui):
    """Build a JSON index that maps each document filename to its collections.

    Args:
        db_path_chroma: Path to the Chroma SQLite file (read from its
            ``collections`` table).
        db_path_webui: Path to the Open WebUI SQLite file (read from its
            ``document`` table).

    Returns:
        A JSON string (indent=4) of the form::

            {filename: {"filepath": filename,
                        "collections": [collection_name, ...],
                        "metadata": {"title": ..., "name": ...}}}

        Collections are de-duplicated and listed in the order they were first
        seen; ``metadata`` comes from the first row seen for each filename.

    Raises:
        FileNotFoundError: if either database file does not exist.
        sqlite3.Error: if a table/column is missing or a query fails.
    """
    query_collections_chroma = """
    SELECT id, name
    FROM collections
    """

    query_document_webui = """
    SELECT id, collection_name, filename, title, name
    FROM document
    """

    # The Chroma rows are not used below; reading them only confirms that the
    # vector-store copy exists and has a ``collections`` table.
    _fetch_all(db_path_chroma, query_collections_chroma)
    documents_webui = _fetch_all(db_path_webui, query_document_webui)

    data = {}
    for _file_id, collection_name, filename, title, name in documents_webui:
        if filename not in data:
            data[filename] = {
                "filepath": filename,
                # A dict used as an insertion-ordered set: de-duplicates while
                # keeping first-seen order, so the JSON output is deterministic
                # (a set's iteration order depends on the hash seed).
                "collections": {},
                "metadata": {
                    "title": title,
                    "name": name
                }
            }
        data[filename]["collections"][collection_name] = None

    # Convert the ordered-set dicts to lists for JSON serialization.
    for entry in data.values():
        entry["collections"] = list(entry["collections"])

    return json.dumps(data, indent=4)

def get_status(*, timeout=10):
    """Call the RAG API root endpoint and return its decoded JSON body.

    Args:
        timeout: Seconds to wait for the server (passed to ``requests``).
            Without a timeout a stalled server would block the script forever.

    Returns:
        The decoded JSON response on HTTP 200. Otherwise, a dict with a single
        ``"error"`` key: on a non-200 status code, or on a 200 whose body is not
        valid JSON.
    """
    url = f"{base_url}/"
    headers = {"Authorization": f"Bearer {bearer_token}"}
    response = requests.get(url, headers=headers, timeout=timeout)
    print(f"Status Response: {response.status_code}, {response.text}")  # Debugging line
    if response.status_code != 200:
        return {"error": f"Failed to get status. HTTP Status Code: {response.status_code}"}
    try:
        return response.json()
    except ValueError:
        # requests raises a ValueError subclass when a 200 body (such as an
        # HTML page) is not JSON; report it like any other failure.
        return {"error": f"Failed to get status. Response was not valid JSON: {response.text[:200]}"}

def query_collection(collection_names, query, *, timeout=60):
    """Send *query* to the RAG API's query-collection endpoint.

    Args:
        collection_names: Names of the collections to search. If empty, no
            request is sent.
        query: The query text.
        timeout: Seconds to wait for the server (passed to ``requests``).
            Without a timeout a stalled server would block the script forever.

    Returns:
        The decoded JSON response on HTTP 200. Otherwise, a dict with a single
        ``"error"`` key: when no collections were given, on a non-200 status
        code, or on a 200 whose body is not valid JSON.
    """
    if not collection_names:
        # Nothing to search; avoid a pointless round-trip to the server.
        return {"error": "Failed to query collections. No collection names were provided."}

    url = f"{base_url}/query/collection"
    headers = {
        "Authorization": f"Bearer {bearer_token}",
        "Content-Type": "application/json"
    }
    payload = {
        "collection_names": collection_names,
        "query": query
    }
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    print(f"Query Collection Response: {response.status_code}, {response.text}")  # Debugging line
    if response.status_code != 200:
        return {"error": f"Failed to query collections. HTTP Status Code: {response.status_code}, Detail: {response.text}"}
    try:
        return response.json()
    except ValueError:
        # requests raises a ValueError subclass when a 200 body is not JSON.
        return {"error": f"Failed to query collections. Response was not valid JSON: {response.text[:200]}"}

def get_collections_from_filenames(json_data, filenames):
    """Return the de-duplicated collection names for the given filenames.

    Args:
        json_data: The filename -> entry index produced by
            ``extract_data_to_json``, either as the JSON string it returns or
            as an already-decoded dict. Each entry has a ``"collections"`` list.
        filenames: Filenames to look up; names not present in the index are
            ignored.

    Returns:
        A list of collection names in first-seen order (following the order
        of *filenames*). The order is deterministic, unlike iterating a set.
    """
    if isinstance(json_data, (str, bytes, bytearray)):
        data = json.loads(json_data)
    else:
        data = json_data

    # dict.fromkeys de-duplicates while preserving insertion order.
    ordered = dict.fromkeys(
        collection
        for filename in filenames
        if filename in data
        for collection in data[filename]["collections"]
    )
    return list(ordered)

def main():
    """Interactively choose documents and run one RAG query against them.

    Steps:

    1. Build the filename -> collections index from the local database copies.
    2. Ask the user for ``All`` or a comma-separated list of filenames.
    3. Print the API status.
    4. Send a single query and print the result.

    Returns early, without querying, when no collections are selected or the
    query is empty.
    """
    # Generate the JSON data from the databases
    json_result = extract_data_to_json(db_path_chroma, db_path_webui)
    print("Extracted JSON Data:", json_result)

    # Display all filenames for the user to choose from
    data = json.loads(json_result)
    all_filenames = list(data.keys())
    print("Available filenames:", all_filenames)

    # Get user input for filenames or "All"
    user_input = input("Enter 'All' to query all collections or a comma-separated list of filenames: ").strip()

    if user_input.lower() == "all":
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # request payload is identical on every run (set order is not).
        collection_names = list(dict.fromkeys(
            collection
            for file in data.values()
            for collection in file["collections"]
        ))
    else:
        # Skip empty entries produced by stray or trailing commas.
        filenames = [name.strip() for name in user_input.split(",") if name.strip()]
        unknown = [name for name in filenames if name not in data]
        if unknown:
            print("Unknown filenames (ignored):", unknown)
        collection_names = get_collections_from_filenames(json_result, filenames)

    if not collection_names:
        print("No collections selected; nothing to query.")
        return

    # Get API status
    status = get_status()
    print("API Status:", status)

    # Get user input for the query
    query = input("Enter your query: ").strip()
    if not query:
        print("Empty query; nothing to send.")
        return

    # Query the collections with the provided query
    collection_result = query_collection(collection_names, query)
    print("Query Collection Result:", collection_result)


if __name__ == "__main__":
    main()