Open ababic opened 6 months ago
In hindsight, perhaps the most flexible approach would be to accept a QuerySet
as an option and restrict results to items within that queryset. With this approach, the above could be rewritten as:
# Proposed usage: restrict similarity results to items within the given queryset.
# NOTE(review): the implementation below accepts this as the positional
# ``model_or_queryset`` argument rather than a ``queryset`` keyword.
related_blogs = index.similar(page, queryset=Page.objects.live().type(BlogPage, SpecialBlogPage))
related_events = index.similar(page, queryset=EventPage.objects.live())
I've attempted to implement this on a custom multi-model index, but the implementation is very specific to the pgvector
backend and depends on using the Django queryset API for lookups.
from collections.abc import Iterable, Iterator
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import Model, QuerySet
from django.db.models.base import ModelBase
from django.db.models.functions import Cast
from wagtail.models import Page
from wagtail_vector_index.backends.pgvector.models import PgvectorEmbeddingQuerySet
from wagtail_vector_index.index.base import Document
from wagtail_vector_index.index.models import EmbeddableFieldsDocumentConverter, EmbeddableFieldsVectorIndex
class SimilarPageVectorIndex(EmbeddableFieldsVectorIndex):
    """
    A multi-model vector index whose ``similar()`` results can be restricted
    to a model class or to the items of a queryset.

    NOTE(review): ``similar()`` and ``get_similar_documents()`` reach into the
    backend index's ``_get_queryset()`` and the ``PgvectorEmbeddingQuerySet``
    API, so this assumes a pgvector backend.
    """

    # This would ideally be generated automatically by evaluating the contents
    # of `querysets`, but this index will only ever be used for pages.
    base_model_class = Page

    # Placeholder: populate with the querysets whose objects should be indexed.
    querysets = [
        ...
    ]

    def get_converter(self, model_class: ModelBase | None = None):
        """
        Override ``EmbeddableFieldsVectorIndex.get_converter()`` to accept the
        ``model_class`` argument, which is used to create a more 'specific'
        converter instance when beneficial.

        Falls back to ``base_model_class`` when ``model_class`` is ``None``.
        """
        return self.get_converter_class()(model_class or self.base_model_class)

    def get_documents(self) -> Iterable[Document]:
        """
        Return documents for every object in every configured queryset.

        Each queryset is converted with a converter built for that queryset's
        own model class.
        """
        all_documents: list[Document] = []
        for queryset in self._get_querysets():
            instances = queryset.prefetch_related("embeddings")
            converter = self.get_converter(queryset.model)
            all_documents.extend(
                converter.bulk_to_documents(
                    instances, embedding_backend=self.embedding_backend
                )
            )
        return all_documents

    def _get_content_types(
        self, *model_classes: ModelBase, exact: bool = False
    ) -> list[ContentType]:
        """
        Return ``ContentType`` objects for ``model_classes``.

        :param model_classes: The model classes to look up.
        :param exact: When True, only the given classes are resolved. When
            False (the default), every installed model that is a subclass of
            any of ``model_classes`` (the classes themselves included) is
            resolved as well.

        NOTE(review): ``get_for_models()`` is called with its default
        ``for_concrete_models=True``; confirm that is desired for proxy models.
        """
        if exact:
            content_types = ContentType.objects.get_for_models(*model_classes)
        else:
            subclasses = {
                model
                for model in apps.get_models()
                if issubclass(model, model_classes)
            }
            content_types = ContentType.objects.get_for_models(*subclasses)
        # Materialise the dict view so callers get a plain list for ``__in`` lookups.
        return list(content_types.values())

    def similar(
        self,
        obj,
        model_or_queryset: QuerySet | ModelBase | None = None,
        *,
        include_self: bool = False,
        limit: int = 5,
        pk_list: bool = False,
    ) -> list:
        """
        Override ``EmbeddableFieldsVectorIndex.similar()`` to accept an optional
        ``model_or_queryset`` positional argument, which can be used to restrict
        results to objects of a specific type, or to ensure results are suitable
        for use in certain contexts.

        It also adds the ``pk_list`` option, which allows the caller to opt out
        of converting results back into model objects and instead receive a
        list of primary keys.

        :param obj: The object to find similar objects for.
        :param model_or_queryset: A model class (results are restricted to that
            model and its subclasses) or a queryset (results are restricted to
            objects in that queryset). ``None`` applies no restriction.
        :param include_self: Whether ``obj`` itself may appear in the results.
        :param limit: Maximum number of matches fetched per document generated
            from ``obj``. When ``obj`` produces several documents, more than
            ``limit`` results may be returned.
        :param pk_list: Return primary keys instead of model instances.
        :returns: A list of model instances, or of primary keys if ``pk_list``.
        """
        is_queryset = isinstance(model_or_queryset, QuerySet)

        # Yes, this backend-agnostic index is taking over control of the queryset
        # used to identify documents, which is not really its responsibility.
        embeddings_qs: QuerySet = self.backend_index._get_queryset().select_related(
            "embedding"
        )
        if is_queryset:
            embeddings_qs = embeddings_qs.filter(
                # Embedding object IDs are compared as strings (presumably stored
                # in a CharField), so cast the queryset's pks to match.
                embedding__object_id__in=model_or_queryset.annotate(
                    pk_as_string=Cast("pk", models.CharField())
                ).values_list("pk_as_string", flat=True),
                embedding__content_type__in=self._get_content_types(
                    model_or_queryset.model
                ),
            )
        elif model_or_queryset is not None:
            embeddings_qs = embeddings_qs.filter(
                embedding__content_type__in=self._get_content_types(model_or_queryset)
            )

        if not include_self:
            # Matches on object_id alone, which relies on all indexed models
            # sharing one pk space (as pages do via the base Page model).
            embeddings_qs = embeddings_qs.exclude(embedding__object_id=str(obj.pk))

        similar_documents: list[Document] = []
        seen_object_ids: set[str] = set()
        # Build the converter from obj's own class so its documents are generated
        # with the right embeddable fields.
        source_converter = self.get_converter(type(obj))
        for document in source_converter.to_documents(
            obj, embedding_backend=self.embedding_backend
        ):
            for similar_document in self.get_similar_documents(
                document.vector, embeddings_qs=embeddings_qs, limit=limit
            ):
                object_id = similar_document.metadata["object_id"]
                if object_id not in seen_object_ids:
                    seen_object_ids.add(object_id)
                    similar_documents.append(similar_document)

        # TODO: similar_documents should really be re-sorted by similarity score,
        # as matches across multiple source documents are currently chunked together.

        if pk_list:
            # Using the base model class should be fine here, as it should be where
            # the concrete 'pk' field is defined for all models in all querysets
            # (this would only fall down if mixing models with no common concrete
            # ancestor).
            to_python = self.base_model_class._meta.pk.to_python
            return [to_python(doc.metadata["object_id"]) for doc in similar_documents]

        # Continue to convert to model instances, using a converter specific to the
        # requested model where one was given. The QuerySet check must come first:
        # ``issubclass()`` raises TypeError for a QuerySet instance.
        if model_or_queryset is None:
            target_class = None
        elif is_queryset:
            target_class = model_or_queryset.model
        else:
            target_class = model_or_queryset
        return self.get_converter(target_class).bulk_from_documents(similar_documents)

    def get_similar_documents(
        self,
        query_vector,
        embeddings_qs: PgvectorEmbeddingQuerySet,
        *,
        limit: int = 5,
    ) -> Iterator[Document]:
        """
        An alternative to ``PgvectorIndex.similarity_search()`` which requires the
        caller to provide the ``PgvectorEmbeddingQuerySet`` to search on.

        Yields up to ``limit`` documents, one per matching embedding. Results are
        ordered with ``order_by_distance()``, using the backend index's
        ``distance_method``. Only embeddings whose output dimensions equal
        ``len(query_vector)`` are considered.
        """
        nearest = embeddings_qs.filter(
            # Only vectors of the same dimensionality can be compared by distance.
            embedding_output_dimensions=len(query_vector)
        ).order_by_distance(
            query_vector,
            distance_method=self.backend_index.distance_method,
            fetch_distance=False,
        )[:limit]
        for pgvector_embedding in nearest.iterator():
            yield pgvector_embedding.embedding.to_document()
When a custom index contains rows for multiple models, it would be nice to be able to filter lookups to include only items of a specific type. For example:
This could work by converting the model class (or sequence of model classes) into a list of `ContentType` objects using `ContentType.objects.get_for_models()`, then applying `.filter(content_type__in=content_types)` when looking up embeddings. Where a single model class has been specified, it would also be good if the converter were initialised with that model class, so that you automatically receive instances of that type instead of a shared parent (e.g. `Page`).