wagtail / wagtail-vector-index

Store Wagtail pages & Django models as embeddings in vector databases
https://wagtail-vector-index.readthedocs.io/en/latest/
MIT License
20 stars 13 forks source link

Support type-specific lookups for `similar()` and `search()` #60

Open ababic opened 6 months ago

ababic commented 6 months ago

Where you have a custom index that contains rows for multiple models, it would be nice to be able to filter those lookups to include items of a specific type. For example:

index = MyCustomIndex()

# Would return instances of BlogPage and SpecialBlogPage only
related_blogs = index.similar(page, object_type=[BlogPage, SpecialBlogPage])

# Would return instances of EventPage only
related_events = index.search("Soft play", object_type=EventPage)

This could work by converting the model class (or sequence of model classes) into a list of ContentType objects using ContentType.objects.get_for_models(), then applying .filter(content_type__in=content_types) when looking up embeddings.

Where a single model class has been specified, it would also be cool if the converter was initialised with that model class, so that you automatically receive instances of that type instead of a shared parent (e.g. Page)

ababic commented 6 months ago

Thinking on this in hindsight, maybe the most flexible approach here would be to accept a QuerySet as an option, and results would be restricted to items within that queryset? With this approach, the above could be rewritten as:

related_blogs = index.similar(page, queryset=Page.objects.live().type(BlogPage, SpecialBlogPage))

related_events = index.similar(page, queryset=EventPage.objects.live())
ababic commented 6 months ago

I've attempted to implement this on a custom multiple-model index, but it's very specific to the pgvector backend, and is dependent on using the Django queryset API for lookups.

from collections.abc import Iterable, Iterator

from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import Model, QuerySet
from django.db.models.base import ModelBase
from django.db.models.functions import Cast
from wagtail.models import Page
from wagtail_vector_index.backends.pgvector.models import PgvectorEmbeddingQuerySet
from wagtail_vector_index.index.base import Document
from wagtail_vector_index.index.models import EmbeddableFieldsDocumentConverter, EmbeddableFieldsVectorIndex

class SimilarPageVectorIndex(EmbeddableFieldsVectorIndex):
    """
    A multi-model vector index whose ``similar()`` lookups can be restricted
    to a specific model class or queryset.

    NOTE(review): this implementation is specific to the pgvector backend and
    depends on the Django queryset API for lookups.
    """

    # This would ideally be generated automatically by evaluating the contents
    # of `querysets`, but I know this index will only ever be used for pages
    base_model_class = Page

    querysets = [
        ...
    ]

    def get_converter(self, model_class: ModelBase | None = None):
        """
        Overrides ``EmbeddableFieldsVectorIndex.get_converter()`` to accept the
        ``model_class`` argument, which is used to create a more 'specific'
        converter instance when beneficial.
        """
        return self.get_converter_class()(model_class or self.base_model_class)

    def get_documents(self) -> Iterable[Document]:
        """
        Build documents for every queryset in ``querysets``, using a
        type-specific converter for each queryset's model.
        """
        all_documents: list[Document] = []
        for queryset in self._get_querysets():
            instances = queryset.prefetch_related("embeddings")
            all_documents.extend(
                self.get_converter(queryset.model).bulk_to_documents(
                    instances, embedding_backend=self.embedding_backend
                )
            )
        return all_documents

    def _get_content_types(self, *models, exact: bool = False):
        """
        Return ``ContentType`` objects for the supplied model classes.

        When ``exact`` is False (the default), content types for all
        registered subclasses of the supplied models are included too, so
        that e.g. passing ``Page`` matches every page type.

        NB: the ``models`` parameter shadows the ``django.db.models`` module
        import inside this method; ``issubclass()`` accepts a tuple of
        classes, so passing the parameter tuple directly is fine.
        """
        if exact:
            return ContentType.objects.get_for_models(*models).values()
        all_subclasses = {
            model for model in apps.get_models() if issubclass(model, models)
        }
        return ContentType.objects.get_for_models(*all_subclasses).values()

    def similar(
        self,
        obj,
        model_or_queryset: QuerySet | ModelBase | None = None,
        *,
        include_self: bool = False,
        limit: int = 5,
        pk_list: bool = False,
    ) -> list:
        """
        Overrides ``EmbeddableFieldsVectorIndex.similar()`` to accept an optional
        ``model_or_queryset`` positional argument, which can be used to restrict
        results to objects of a specific type, or to ensure results are suitable
        for use in certain contexts.

        It also adds the ``pk_list`` option, which allows the caller to opt out
        of conversion of results back into model objects, and instead just
        receive a list of primary keys.
        """
        # NOTE: this backend-agnostic index is taking over control of the
        # queryset used to identify documents, which is not really its
        # responsibility.
        be_index = self.backend_index
        embeddings_qs: QuerySet = be_index._get_queryset().select_related("embedding")

        if isinstance(model_or_queryset, QuerySet):
            # Restrict matches to objects within the supplied queryset.
            # Embedding object ids are stored as strings, so cast pks for
            # comparison.
            embeddings_qs = embeddings_qs.filter(
                embedding__object_id__in=model_or_queryset.annotate(
                    pk_as_string=Cast("pk", models.CharField())
                ).values_list("pk_as_string", flat=True),
                embedding__content_type__in=self._get_content_types(
                    model_or_queryset.model
                ),
            )
        elif model_or_queryset is not None:
            embeddings_qs = embeddings_qs.filter(
                embedding__content_type__in=self._get_content_types(model_or_queryset)
            )
        if not include_self:
            embeddings_qs = embeddings_qs.exclude(embedding__object_id=str(obj.pk))

        similar_documents: list[Document] = []
        seen_pks: set[str] = set()
        # BUGFIX: was ``type(object)`` (the builtin), which produced a
        # converter for the wrong class; we want the type of ``obj``.
        for document in self.get_converter(type(obj)).to_documents(
            obj, embedding_backend=self.embedding_backend
        ):
            for similar_document in self.get_similar_documents(
                document.vector, embeddings_qs=embeddings_qs, limit=limit
            ):
                if similar_document.metadata["object_id"] not in seen_pks:
                    similar_documents.append(similar_document)
                    seen_pks.add(similar_document.metadata["object_id"])

        # TODO: similar_documents should really be re-sorted by similarity score
        # now, as matches across multiple documents are currently chunked together

        if pk_list:
            # Using the base model class should be fine here, as it should be
            # where the concrete 'pk' field is defined for all models in all
            # querysets (this would only fall down if mixing models with no
            # common concrete ancestor)
            return [
                self.base_model_class._meta.pk.to_python(doc.metadata["object_id"])
                for doc in similar_documents
            ]

        # Continue to convert to model instances.
        # BUGFIX: test for QuerySet *before* treating the value as a class —
        # ``issubclass()`` raises TypeError when handed a QuerySet instance.
        if model_or_queryset is None:
            target_class = None
        elif isinstance(model_or_queryset, QuerySet):
            target_class = model_or_queryset.model
        else:
            target_class = model_or_queryset
        converter = self.get_converter(target_class)
        return converter.bulk_from_documents(similar_documents)

    def get_similar_documents(
        self,
        query_vector,
        embeddings_qs: PgvectorEmbeddingQuerySet,
        *,
        limit: int = 5,
    ) -> Iterator[Document]:
        """
        An alternative to ``PgvectorIndex.similarity_search()``, which requires
        the caller to provide a ``PgvectorEmbeddingQuerySet`` to search on.

        Yields up to ``limit`` documents ordered by vector distance, filtered
        to embeddings whose output dimensions match ``query_vector``.
        """
        for pgvector_embedding in (
            embeddings_qs.filter(embedding_output_dimensions=len(query_vector))
            .order_by_distance(
                query_vector,
                distance_method=self.backend_index.distance_method,
                fetch_distance=False,
            )[:limit]
            .iterator()
        ):
            embedding = pgvector_embedding.embedding
            yield embedding.to_document()