holoviz / lumen

Illuminate your data.
https://lumen.holoviz.org
BSD 3-Clause "New" or "Revised" License
177 stars 20 forks source link

Create Embeddings and VectorStore #750

Open ahuang11 opened 6 days ago

ahuang11 commented 6 days ago

I am planning to refactor the existing Embeddings class.

The purpose is to supply LLMs with up-to-date or private data using retrieval-augmented generation (RAG) because LLMs are trained on static and sometimes outdated datasets, which may not provide accurate information.

Goals:


I propose the following interfaces with minimal methods, so as to keep the interface from being too rigid. I did not add delete to the interface because some of the stores, like the NumPy one, are not persistent. DuckDB will have it though.

class Embeddings(ABC):
    """Abstract interface for text-embedding backends.

    Concrete implementations (OpenAI, MistralAI, HuggingFace, ...) only
    need to provide :meth:`embed`.
    """

    @abstractmethod
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Map each input text to its embedding vector.

        Args:
            texts: The texts to embed.

        Returns:
            One list of floats (the embedding) per input text, in order.
        """

class VectorStore(ABC):
    """Abstract interface for a similarity-search store over embedded texts.

    A concrete store is constructed around an :class:`Embeddings` backend
    and must implement :meth:`add` and :meth:`query`.
    """

    def __init__(self, embedding_model: 'Embeddings'):
        # The embedding backend used to vectorize texts on add/query.
        self.embedding_model = embedding_model

    @abstractmethod
    def add(self, texts: List[str], metadata: Optional[List[Dict]] = None) -> List[int]:
        """Store the given texts (and optional per-text metadata).

        Args:
            texts: Texts to embed and store.
            metadata: Optional metadata dict for each text, aligned by index.

        Returns:
            The unique ID assigned to each added text.
        """

    @abstractmethod
    def query(self, text: str, top_k: int = 5) -> List[Dict]:
        """Search the store for texts similar to ``text``.

        Args:
            text: The query text.
            top_k: Maximum number of matches to return.

        Returns:
            Matching records, each carrying text, metadata and a
            similarity score.
        """

Then, I am planning to implement these:

OpenAI Embeddings, MistralAI Embeddings, HuggingFace Embeddings, WordLlama Embeddings

Example:

class OpenAIEmbeddings(Embeddings):
    """Embeddings backend using the OpenAI embeddings API."""

    def __init__(self, api_key: str, model: str = 'text-embedding-3-small'):
        """
        Args:
            api_key: OpenAI API key used to authenticate the client.
            model: Name of the OpenAI embedding model to use.
        """
        from openai import OpenAI
        # Bug fix: the original dropped both constructor arguments —
        # api_key was never handed to the client, and self.model was never
        # set even though embed() reads it (AttributeError at query time).
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text, in input order."""
        # Newlines can degrade embedding quality, so flatten them to spaces.
        texts = [text.replace("\n", " ") for text in texts]
        response = self.client.embeddings.create(input=texts, model=self.model)
        return [r.embedding for r in response.data]

With these stores:

NumpyVectorStore (memory), DuckDBVectorStore (persistent), WordLlamaVectorStore (memory), ChromaVectorStore (variety)

Example:

import duckdb
import json
from typing import List, Optional, Dict

class DuckDBVectorStore(VectorStore):
    """DuckDB-backed vector store; persistent when ``db_path`` is a file path."""

    def __init__(self, embedding_model: 'Embeddings', db_path: str = ':memory:'):
        """
        Args:
            embedding_model: Backend used to embed texts on add/query.
            db_path: DuckDB database path; ':memory:' keeps the store in RAM.
        """
        super().__init__(embedding_model)
        self.connection = duckdb.connect(database=db_path)
        self._setup_database()

    def _setup_database(self) -> None:
        """Create the documents table (and best-effort HNSW index)."""
        # Bug fix: DuckDB has no AUTO_INCREMENT; emulate it with a
        # sequence-backed default on the primary key.
        self.connection.execute("CREATE SEQUENCE IF NOT EXISTS documents_id_seq;")
        self.connection.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                id BIGINT PRIMARY KEY DEFAULT nextval('documents_id_seq'),
                text VARCHAR,
                embedding FLOAT[],
                table_name VARCHAR,
                metadata JSON
            );
        """)
        # The HNSW index requires the `vss` extension (and a fixed-size
        # FLOAT[N] column); treat it as best-effort so the store still
        # works — just without the index — when the extension is absent.
        try:
            self.connection.execute("INSTALL vss; LOAD vss;")
            self.connection.execute("""
                CREATE INDEX IF NOT EXISTS embedding_index
                ON documents USING HNSW (embedding) WITH (metric = 'cosine');
            """)
        except duckdb.Error:
            pass

    def add(self, texts: List[str], metadata: Optional[List[Dict]] = None, table_name: str = "default") -> List[int]:
        """Embed and insert texts; returns the generated row IDs in order.

        Raises:
            ValueError: If ``metadata`` is given but its length differs
                from ``texts``.
        """
        if metadata is not None and len(metadata) != len(texts):
            raise ValueError("metadata must have one entry per text")
        embeddings = self.embedding_model.embed(texts)
        text_ids = []
        for i, (text, embedding) in enumerate(zip(texts, embeddings)):
            meta = metadata[i] if metadata else {}
            result = self.connection.execute("""
                INSERT INTO documents (text, embedding, table_name, metadata)
                VALUES (?, ?, ?, ?) RETURNING id;
            """, [text, embedding, table_name, json.dumps(meta)])
            text_ids.append(result.fetchone()[0])  # generated primary key
        return text_ids

    def delete(self, text_ids: List[int]) -> None:
        """Remove the rows with the given IDs (no-op for an empty list)."""
        if not text_ids:
            return
        # Bug fix: binding a tuple to `IN ?` is not valid DuckDB parameter
        # syntax; expand one placeholder per ID instead.
        placeholders = ", ".join("?" for _ in text_ids)
        self.connection.execute(
            f"DELETE FROM documents WHERE id IN ({placeholders});",
            list(text_ids),
        )

    def query(self, text: str, top_k: int = 5, table_name: Optional[str] = None) -> List[Dict]:
        """Return the ``top_k`` most similar stored texts to ``text``."""
        query_embedding = self.embedding_model.embed([text])[0]
        # Bug fix: the original computed cosine *distance*, labelled it
        # "similarity" and sorted ascending. Use list_cosine_similarity and
        # sort descending so higher scores really mean more similar.
        sql = """
            SELECT id, text, metadata,
                   list_cosine_similarity(embedding, ?) AS similarity
            FROM documents
        """
        params: list = [query_embedding]
        if table_name:
            sql += " WHERE table_name = ?"
            params.append(table_name)
        sql += " ORDER BY similarity DESC LIMIT ?;"
        params.append(top_k)
        rows = self.connection.execute(sql, params).fetchall()
        return [
            {"id": row[0], "text": row[1], "metadata": json.loads(row[2]), "similarity": row[3]}
            for row in rows
        ]

    def lookup_text_ids(self, texts: Optional[List[str]] = None, metadata: Optional[Dict] = None) -> List[int]:
        """Find row IDs matching any of ``texts`` and/or all ``metadata`` pairs."""
        query = "SELECT id FROM documents WHERE 1=1"
        params: list = []

        if texts:
            # Same `IN ?` fix as delete(): expand explicit placeholders.
            placeholders = ", ".join("?" for _ in texts)
            query += f" AND text IN ({placeholders})"
            params.extend(texts)

        if metadata:
            # Bug fix: `@>` is PostgreSQL jsonb containment, not DuckDB.
            # Match each key/value pair via json_extract_string instead.
            # NOTE(review): non-string values are compared via str(value) —
            # confirm this matches how metadata is serialized on insert.
            for key, value in metadata.items():
                query += " AND json_extract_string(metadata, ?) = ?"
                params.extend([f"$.{key}", str(value)])

        result = self.connection.execute(query, params).fetchall()
        return [row[0] for row in result]
philippjfr commented 1 day ago

def add(self, texts: List[str], metadata: Optional[List[Dict]] = None, table_name: str = "default") -> List[int]:

Maybe combine the texts and metadata into one list, and let's just shove the table_name into the metadata for now.