RupertAvery / DiffusionToolkit

Metadata-indexer and Viewer for AI-generated images
MIT License

Prompt search - Cosine similarity #190

Open RupertAvery opened 9 months ago

RupertAvery commented 9 months ago

I'd like to be able to search prompts by similarity. I was thinking of computing a vector of each prompt, and then calculating the cosine similarity. Unfortunately I don't have a good idea of how to do this at all, just the concepts.

I suppose I need to tokenize the prompt, and then convert it to a vector using something like Word2Vec.

danrazor commented 8 months ago

I asked BARD and here are some tips from it.

Implementing prompt similarity search in your Stable Diffusion image database app using vector representations and cosine similarity can significantly enhance user experience and image discovery. Here's how you can achieve it:

1. Image Embeddings:

Instead of text prompts, extract vector representations (embeddings) directly from the Stable Diffusion images in your database. This eliminates the need for text tokenization and word embedding models. Consider using pre-trained image embedding models like CLIP, ViT, or Swin Transformer. These models excel at capturing the semantic content of images.

2. Indexing and Searching:

Store the extracted image embeddings alongside the corresponding prompts and other metadata (tags, categories, etc.) in a searchable database (e.g., Elasticsearch, Faiss). When a user searches for a prompt, create a new embedding for the prompt text using the same model used for the images. Compare the query embedding with the stored image embeddings using cosine similarity.

3. Ranking and Retrieval:

Rank images based on their cosine similarity scores with the query embedding. Images with higher scores are deemed more relevant to the search query. Return a list of the top-ranked images to the user.

4. Advanced Techniques:

Implement hybrid search by combining textual and visual features. This allows users to refine their search with both text and image examples. Consider personalized search by incorporating user preferences and interaction history to fine-tune the retrieved results.

Benefits:

- Efficient and accurate image retrieval based on semantic understanding.
- Enables discovery of visually similar images, even for nuanced prompts.
- Offers a more intuitive and interactive search experience for users.
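As a concrete illustration of steps 2 and 3, here is a minimal C# sketch of the cosine-similarity ranking step. It assumes the embeddings are plain float[] vectors of equal length, however they were produced, and the dictionary/id names are just placeholders:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class CosineSearch
{
    // Cosine similarity between two equal-length embedding vectors.
    public static double CosineSimilarity(float[] a, float[] b)
    {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.Sqrt(normA) * Math.Sqrt(normB));
    }

    // Rank stored embeddings against a query embedding and return the top-k ids.
    public static IEnumerable<int> TopK(float[] query, IReadOnlyDictionary<int, float[]> stored, int k)
    {
        return stored
            .OrderByDescending(kv => CosineSimilarity(query, kv.Value))
            .Take(k)
            .Select(kv => kv.Key);
    }
}
```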

curiousjp commented 8 months ago

It is a shame that we don’t have the vector of CLIPped tokens available in the file metadata. One approach could be to re-CLIP the prompt during metadata ingestion and hold it in the database (could be slow), but you would also need to code special cases for removing things like a1111 attention notation / BREAK etc.

After this, you’d have n 768-dimensional vectors representing your prompt (where n is some multiple of 77, due to Comfy's and a1111's support for infinite prompt length).

These could potentially be reduced to a single vector (just sum them?) and then cosine similarity could be calculated.
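A rough sketch of that collapse step, assuming the per-token CLIP vectors arrive as a list of 768-dimensional float[] rows. This is just the naive averaging idea, not necessarily a good pooling strategy:

```csharp
using System.Collections.Generic;

// Naive pooling sketch: collapse n per-token CLIP vectors (768 dims each)
// into a single prompt vector by averaging. This is the "just sum them?" idea;
// it may lose a lot of information compared to a proper sentence embedding.
static float[] MeanPool(IReadOnlyList<float[]> tokenVectors)
{
    const int dims = 768;
    var pooled = new float[dims];
    foreach (var v in tokenVectors)
        for (int i = 0; i < dims; i++)
            pooled[i] += v[i];
    for (int i = 0; i < dims; i++)
        pooled[i] /= tokenVectors.Count;
    return pooled;
}
```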

Alternatively, after stripping away all the notation you could build your own encoding over all the known words in the database instead of using CLIP. I’m not sure if this is likely to have a higher or lower dimension as it would depend on the prompts a user generates.

I might look at putting something like this together in Python to experiment with. Do you see this as part of an “order by similarity to a prompt I provide” feature or something else? I tried to solve this problem locally by adding tags to my prompt which don’t render but allow me to group similar images by their presence.

edit: after reading the pinecone.io article on sentence embeddings, it seems that collapsing the CLIPped tokens via averaging or similar to get a whole-sentence embedding is a bad approach. sentence-transformers provides a very simple library interface that abstracts all of this away, but as far as I know it isn't available in .NET

RupertAvery commented 8 months ago

A lot of prompts I have are variations using Dynamic Prompts and end up being slightly different. In the prompts tab my initial approach was to do a Levenshtein distance, but I thought something like a vector approach might be better, although I don't want to have to rely on an external tool (I don't know much about CLIP).

I just thought it might be possible to do some sort of sentence similarity without a full pretrained corpus or weights (or maybe by training from existing data), but I may be completely wrong.

I guess I'm thinking of how to cluster prompts based on tokens/words in them. And of course if possible implement it entirely in C#.

curiousjp commented 8 months ago

I'm sure something can be achieved, even without going down the road into embeddings, but the question will be about what meaning of similarity is useful for you.

If you want semantic similarity (i.e. that 'red hair' is closer to 'black hair' than to 'red car'), you might need to reach either for embeddings (particularly if weight is relevant to your idea of similarity) or for some form of in-prompt role tagging for your dynamic sections. But that might be overkill.

If you're happy to put semantic similarity to one side, and if you use a model that relies on booru style tagging, and you don't care about weights, an implementation of Levenshtein or Damerau-Levenshtein at the tag level rather than the character level might be worth considering. If you don't care about weights or prompt ordering, you could reduce the tags to a set and then use Jaccard similarity. I would probably start here - LINQ will give you easy access to the intersection and union operators you need for it if you can get over the hurdle of stripping your prompt down to just the sequence of tags.
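For the Jaccard option, a minimal LINQ sketch; the tag extraction here is just a naive comma split and would still need the notation stripping discussed above:

```csharp
using System.Linq;

static class TagSimilarity
{
    // Naive tag extraction: split on commas and trim. Real prompts would need
    // weight/attention notation stripped first.
    public static string[] Tags(string prompt) =>
        prompt.Split(',')
              .Select(t => t.Trim().ToLowerInvariant())
              .Where(t => t.Length > 0)
              .ToArray();

    // Jaccard similarity: |A ∩ B| / |A ∪ B| over the two tag sets.
    public static double Jaccard(string prompt1, string prompt2)
    {
        var a = Tags(prompt1);
        var b = Tags(prompt2);
        var intersection = a.Intersect(b).Count();
        var union = a.Union(b).Count();
        return union == 0 ? 0 : (double)intersection / union;
    }
}
```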

These are direct comparisons between the two tag sequences, so there's no need to build up a list of every tag ever sighted in the database, but keep in mind that it also makes for a lot of comparisons as you add more items to the database (as does cosine similarity, but that is at least quite quick to do). There are implementations of L/D-L (and other algorithms like longest common subsequence) easily available on NuGet, but as I understand it they all work either on a per-character basis or on fixed-length n-grams, so you may need to write your own.
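For completeness, a sketch of Levenshtein at the tag level rather than the character level; it's the standard dynamic-programming recurrence, just run over arrays of tags:

```csharp
using System;

// Levenshtein distance computed over sequences of tags instead of characters,
// so insertions, deletions and substitutions all operate on whole tags.
static int TagLevenshtein(string[] a, string[] b)
{
    var d = new int[a.Length + 1, b.Length + 1];
    for (int i = 0; i <= a.Length; i++) d[i, 0] = i;
    for (int j = 0; j <= b.Length; j++) d[0, j] = j;

    for (int i = 1; i <= a.Length; i++)
    {
        for (int j = 1; j <= b.Length; j++)
        {
            int cost = a[i - 1] == b[j - 1] ? 0 : 1;
            d[i, j] = Math.Min(Math.Min(
                d[i - 1, j] + 1,         // deletion
                d[i, j - 1] + 1),        // insertion
                d[i - 1, j - 1] + cost); // substitution
        }
    }
    return d[a.Length, b.Length];
}
```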

RupertAvery commented 8 months ago

So I looked at Jaccard index, and when I saw that it's based on sets, I had an idea.

I have a library called SparseBitsets; basically it encodes unique ids into a sparse bit array, and you can then perform AND/OR/NOT operations between bitsets.

So if you parse all the prompts (it's actually quite fast, just a very simple split on space and comma for now) and process each tag into a one-hot encoding, you'll have a vocabulary of tags in your database.

Then you reprocess the prompts, this time looking up each tag in the vocabulary and assigning its index to a bit in the bitset.

So now you have a set of all unique tags in each prompt, represented as a bitset.

Calculating the Intersection and Union of these is dead simple: it just calls bitwise AND/OR on the underlying bits.

The results are encouraging. And best of all, it's fast!

If you're interested, I pushed a branch named "jaccard-index". The code is in the TestHarness project. It will parse your prompts into bitsets and then it will loop, taking a random prompt from your list and then looking for similar prompts based on the Jaccard index.

The catch of course is that you have to have all the bitsets in memory to do a search, but the footprint is pretty small, since a prompt has around 200+ tokens, which encode to around 25 bytes (I actually store them in ulongs, so 64 bits is the smallest storage unit).
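A simplified sketch of the mechanics (the actual branch uses SparseBitsets; this version just packs vocabulary indices into raw ulong words and popcounts the AND/OR results, with hypothetical class and method names):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;

class PromptBitsets
{
    // Vocabulary: every unique tag seen in the database, mapped to a bit index.
    private readonly Dictionary<string, int> _vocab = new();

    private static IEnumerable<string> Tokenize(string prompt) =>
        prompt.Split(new[] { ' ', ',' }, StringSplitOptions.RemoveEmptyEntries)
              .Select(t => t.ToLowerInvariant());

    public void BuildVocabulary(IEnumerable<string> prompts)
    {
        foreach (var tag in prompts.SelectMany(Tokenize))
            if (!_vocab.ContainsKey(tag))
                _vocab[tag] = _vocab.Count;
    }

    // Encode a prompt as a fixed-size bitset: one bit per vocabulary entry.
    public ulong[] Encode(string prompt)
    {
        var bits = new ulong[(_vocab.Count + 63) / 64];
        foreach (var tag in Tokenize(prompt))
            if (_vocab.TryGetValue(tag, out var index))
                bits[index / 64] |= 1UL << (index % 64);
        return bits;
    }

    // Jaccard index = popcount(a AND b) / popcount(a OR b).
    public static double Jaccard(ulong[] a, ulong[] b)
    {
        long intersection = 0, union = 0;
        for (int i = 0; i < a.Length; i++)
        {
            intersection += BitOperations.PopCount(a[i] & b[i]);
            union += BitOperations.PopCount(a[i] | b[i]);
        }
        return union == 0 ? 0 : (double)intersection / union;
    }
}
```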

curiousjp commented 8 months ago

The results are encouraging. And best of all, it's fast!

Yes, performance of the encoding section looks very good! When were you thinking of updating the bitset encodings? Perhaps when the user enters the prompt searching function?

You probably know this already, but LINQ has an .OrderByDescending() operator that lets you sort and take the top five before iterating for display. Perhaps as a first pass at cutting out some of the weight punctuation in the positive prompt, you could try something like this in Tokenize(): return Regex.Replace(text, @"(?:\s{2,})|(?:(?::\d*\.?\d*)?[\)\]\}])|(?:[\(\[\{])", "", RegexOptions.Multiline).Split(new char[] { ' ', ',' });
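Putting those two suggestions together as a rough sketch (the regex is the one above; everything else, including the method names, is just illustrative):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

static class PromptSearch
{
    // Strip a1111-style weight punctuation before splitting on spaces and commas.
    static string[] Tokenize(string text) =>
        Regex.Replace(text,
                @"(?:\s{2,})|(?:(?::\d*\.?\d*)?[\)\]\}])|(?:[\(\[\{])",
                "", RegexOptions.Multiline)
            .Split(new[] { ' ', ',' }, StringSplitOptions.RemoveEmptyEntries);

    // Return the five most similar prompts, highest score first,
    // given some similarity function over two tag arrays (e.g. Jaccard).
    static IEnumerable<(string Prompt, double Score)> TopFive(
        string query, IEnumerable<string> prompts, Func<string[], string[], double> similarity)
    {
        var queryTags = Tokenize(query);
        return prompts
            .Select(p => (Prompt: p, Score: similarity(queryTags, Tokenize(p))))
            .OrderByDescending(x => x.Score)
            .Take(5);
    }
}
```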

RupertAvery commented 8 months ago

I'm moving prompt search into the search page, so that we have the full thumbnail + preview instead of the separate view, but yes, for now the bitset encodings are being generated on the fly. Since they don't change, I'm planning to bulk-generate them once on startup as a migration step (or as a user opt-in), persist them to the DB, and load them up on the next startup. After that they will be generated each time an image is scanned in, since it's quite fast.

The user should be able to rebuild it as necessary.

Thanks for the Regex!

RupertAvery commented 8 months ago

Using this algorithm for plain search doesn't usually give a good set of results (although there may still be a lot of bugs).

When the number of search tokens is smaller than the target's, the Jaccard index is of course going to be low, and a plain text or token search will yield better results instead.

But I still think it can be useful for matching or grouping dynamic prompts together, possibly as a preprocessing step. I'm mostly typing this to keep the idea in mind. Searching against prompts will usually return individual prompts, so there will be a lot of results, but many of them will be similar. So perhaps we could use it to group prompts together so they appear as one "group of prompts", with all the associated images grouped under them.
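One possible shape for that grouping, as a sketch: greedily assign each prompt to the first existing group whose representative exceeds some Jaccard threshold (using the Jaccard helper from the bitset sketch above; the 0.8 threshold is just a guess):

```csharp
using System.Collections.Generic;

// Greedy grouping sketch: a prompt joins the first group whose representative
// bitset is at least `threshold` similar by Jaccard, otherwise it starts a new group.
static List<List<int>> GroupPrompts(IReadOnlyList<ulong[]> bitsets, double threshold = 0.8)
{
    var groups = new List<List<int>>();
    for (int i = 0; i < bitsets.Count; i++)
    {
        var placed = false;
        foreach (var group in groups)
        {
            // Compare against the group's first member as its representative.
            if (PromptBitsets.Jaccard(bitsets[group[0]], bitsets[i]) >= threshold)
            {
                group.Add(i);
                placed = true;
                break;
            }
        }
        if (!placed)
            groups.Add(new List<int> { i });
    }
    return groups;
}
```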

curiousjp commented 8 months ago

When the number of search tokens is smaller than the target's, the Jaccard index is of course going to be low, and a plain text or token search will yield better results instead.

One possible alternative here is that instead of calculating the Jaccard, you can just calculate the size of the intersection of the two sets divided by the size of the search-term set. Perhaps as a second option.
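In bitset terms that would just be a change of denominator; a sketch, assuming the query and target bitsets are the same length:

```csharp
// Variant of the bitset Jaccard: intersection size divided by the size of the
// search-term set, so a short query fully contained in a long prompt scores 1.0.
// Assumes query and target were encoded against the same vocabulary and have equal length.
static double QueryCoverage(ulong[] query, ulong[] target)
{
    long intersection = 0, querySize = 0;
    for (int i = 0; i < query.Length; i++)
    {
        intersection += System.Numerics.BitOperations.PopCount(query[i] & target[i]);
        querySize += System.Numerics.BitOperations.PopCount(query[i]);
    }
    return querySize == 0 ? 0 : (double)intersection / querySize;
}
```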

It sounds like in some ways you could be working towards a clustering solution - while I'm not sure it would be a good fit for DT's existing UI, some people have done some interesting work on clustering their prompts along with another metric (like CLIP score) as a way to find "good prompts" - see, e.g. here.

curiousjp commented 5 months ago

Did you end up exploring this space further? I have recently been playing with cluster visualisations - I use a tagger to tag all the images under a folder, and then encode those into a big vector for each image. The vectors get dimensionally reduced using umap, and then clustered with hdbscan. Finally, I scatter-plot them in a workbook, using a local http server shim to serve the image as part of the tooltip when you mouse over each datapoint. It works surprisingly well - the biggest irritation is that nobody seems to have an out-of-the-box solution for displaying Jupyter output cells full screen.

I'm not sure something like this would be a good fit for Diffusion.Toolkit, but I thought I'd mention it in case it gave you any ideas.