asg017 / sqlite-vss

A SQLite extension for efficient vector search, based on Faiss!
MIT License

add metric type for vector comparison #75

Closed: dleviminzi closed this 11 months ago

dleviminzi commented 12 months ago

Decided to take a stab at #6. The tests pass, and I also tested it further in my own Go project. Let me know if there is anything you'd like changed.

I was wondering whether we want to skip checking for the metric_type option when the factory option is not provided. As written, it still checks for metric_type even without a factory option. I can change that if you prefer.

asg017 commented 11 months ago

This looks great @dleviminzi , thanks for contributing!

Re checking whether the factory arg is passed: the current behavior is fine. Someone may use the default factory string and still want a different metric_type, and the default factory string Flat,IDMap2 works with any of the metric types.

Gonna test this PR a bit and merge if I don't find anything. I'll most likely add a few tests and push a few commits before making the next release, so this should appear in the upcoming v0.1.1 release!

The only thing I can think of now is that metric_type=Lp has a metric_arg parameter to 'set the power'. I think we'll want to include an extra parameter for this. A few ideas:

-- Option 1: a new `metric_arg` parameter
create virtual table vss_demo using vss0(
  embedding(376) metric_type=Lp metric_arg=4
);

-- Option 2: include as a quasi-constraint on Lp directly
create virtual table vss_demo using vss0(
  embedding(376) metric_type=Lp(4)
);

Not required for this PR though - I'm down to just comment out the Lp option for now and re-add it later.
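
For comparison, here is a minimal Python sketch of how a column definition could be read under either proposal. The `parse_metric_options` helper and its regex grammar are hypothetical illustrations, not the extension's actual (C++) argument parser.

```python
import re
from typing import Optional, Tuple

# Hypothetical grammar covering both proposed spellings:
#   Option 1: "metric_type=Lp metric_arg=4"
#   Option 2: "metric_type=Lp(4)"
# Each match yields (key, value, optional parenthesized argument).
_OPTION = re.compile(r"(\w+)=([\w.]+)(?:\(([^)]*)\))?")


def parse_metric_options(column_def: str) -> Tuple[Optional[str], Optional[float]]:
    """Return (metric_type, metric_arg) found in a column definition."""
    metric_type: Optional[str] = None
    metric_arg: Optional[float] = None
    for key, value, inline_arg in _OPTION.findall(column_def):
        if key == "metric_type":
            metric_type = value
            if inline_arg:  # Option 2: argument attached to the metric
                metric_arg = float(inline_arg)
        elif key == "metric_arg":  # Option 1: separate parameter
            metric_arg = float(value)
    if metric_arg is not None and metric_type != "Lp":
        raise ValueError("metric_arg only applies to metric_type=Lp")
    return metric_type, metric_arg


print(parse_metric_options("embedding(376) metric_type=Lp metric_arg=4"))  # ('Lp', 4.0)
print(parse_metric_options("embedding(376) metric_type=Lp(4)"))            # ('Lp', 4.0)
```

Either spelling still needs a check that the argument is only given for Lp, as the `ValueError` above does; the choice between them is mostly about how the SQL reads.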

I also can't get the jensenshannon option to work locally (it doesn't seem to actually insert data into the index), but I might be doing something wrong.

dleviminzi commented 11 months ago

Awesome, I'll remove that comment then.

Aesthetically, I like adding the metric arg as a quasi-constraint on Lp more than adding another parameter.

I'll also see about adding some tests and figuring out whether I hit the same jensenshannon issue. I should be able to work on that after work in a couple of hours.

dleviminzi commented 11 months ago

Re the `metric_arg` parameter for metric_type=Lp:

I tried to get this working but have decided to give up for now. My attempt is on this branch: https://github.com/dleviminzi/sqlite-vss/tree/dleviminzi/lp-metric-arg.

I verified that the metric argument is parsed and passed to the Faiss index, but it doesn't seem to actually be used. I must be missing something about how this is meant to be done.
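
When checking whether the power is actually applied, it helps to have values worked out independently. Below is a small Python reference; `lp_distance` is a hypothetical helper. It assumes Faiss's Lp metric reports the sum of |x_i - y_i|^p without taking the p-th root, by analogy with L2, which Faiss reports squared (the test later in this thread expects 17.0 for [4,1] vs [0,0], not √17). Treat that as an assumption to confirm against Faiss's source.

```python
from typing import Sequence


def lp_distance(x: Sequence[float], y: Sequence[float], p: float,
                take_root: bool = False) -> float:
    """Sum of |x_i - y_i| ** p, optionally followed by the p-th root.

    Assumption: Faiss's Lp metric returns the sum without the root,
    mirroring how its L2 metric returns squared distances.
    """
    if len(x) != len(y):
        raise ValueError("vectors must have the same dimension")
    if p <= 0:
        raise ValueError("p must be positive")
    total = sum(abs(a - b) ** p for a, b in zip(x, y))
    return total ** (1.0 / p) if take_root else total


# Hand-checkable values for the stored vector [4, 1] queried with [0, 0]
print(lp_distance([4, 1], [0, 0], 1))  # 5   (same as the L1 expectation)
print(lp_distance([4, 1], [0, 0], 2))  # 17  (same as the squared-L2 expectation)
print(lp_distance([4, 1], [0, 0], 4))  # 257
```

If changing the power leaves the distance returned by the index unchanged, that suggests the argument is not reaching the index that actually computes distances.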

Re jensenshannon not inserting data: I wrote a test for this one and for cosine similarity. Both seem to be working correctly as far as I can tell, though I might be missing something, so lmk. (I'll go back and write more tests for the other metric types tomorrow morning.)

asg017 commented 11 months ago

Here's a test you can add that covers all the different metric_types (except Lp). Feel free to add it if you have time; otherwise I'll add it after merging.

diff --git a/tests/test-loadable.py b/tests/test-loadable.py
index c6a65c0..e4a6e87 100644
--- a/tests/test-loadable.py
+++ b/tests/test-loadable.py
@@ -534,6 +534,65 @@ class TestVss(unittest.TestCase):
       [{'distance': 2.0, 'rowid': 1000}]
     )

+  def test_vss0_metric_type(self):
+    cur = db.cursor()
+    execute_all(
+      cur,
+      '''create virtual table vss_mts using vss0(
+        ip(2) metric_type=INNER_PRODUCT,
+        l1(2) metric_type=L1,
+        l2(2) metric_type=L2,
+        linf(2) metric_type=Linf,
+        -- lp(2) metric_type=Lp,
+        canberra(2) metric_type=Canberra,
+        braycurtis(2) metric_type=BrayCurtis,
+        jensenshannon(2) metric_type=JensenShannon
+      )'''
+    )
+    idxs = list(map(lambda row: row[0], db.execute("select idx from vss_mts_index").fetchall()))
+
+    # ensure all the indexes are IDMap2 ("IxM2")
+    for idx in idxs:
+      idx_type = idx[0:4]
+      self.assertEqual(idx_type, b"IxM2")
+
+    # manually tested until i ended up at 33 ¯\_(ツ)_/¯
+    METRIC_TYPE_OFFSET = 33
+
+    # values should match https://github.com/facebookresearch/faiss/blob/43d86e30736ede853c384b24667fc3ab897d6ba9/faiss/MetricType.h#L22
+    self.assertEqual(idxs[0][METRIC_TYPE_OFFSET], 0) # ip
+    self.assertEqual(idxs[1][METRIC_TYPE_OFFSET], 2) # l1
+    self.assertEqual(idxs[2][METRIC_TYPE_OFFSET], 1) # l2
+    self.assertEqual(idxs[3][METRIC_TYPE_OFFSET], 3) # linf
+    #self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 4) # lp
+    self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 20) # canberra
+    self.assertEqual(idxs[5][METRIC_TYPE_OFFSET], 21) # braycurtis
+    self.assertEqual(idxs[6][METRIC_TYPE_OFFSET], 22) # jensenshannon
+
+
+    db.execute(
+      "insert into vss_mts(rowid, ip, l1, l2, linf, /*lp,*/ canberra, braycurtis, jensenshannon) values (1, ?1,?1,?1,?1, /*?1,*/ ?1,?1,?1)",
+      ["[4,1]"]
+    )
+    db.commit()
+
+    def distance_of(metric_type, query):
+      return db.execute(
+       f"select distance from vss_mts where vss_search({metric_type}, vss_search_params(?1, 1))",
+       [query]
+      ).fetchone()[0]
+
+    self.assertEqual(distance_of("ip",          "[0,0]"), 0.0)
+    self.assertEqual(distance_of("l1",          "[0,0]"), 5.0)
+    self.assertEqual(distance_of("l2",          "[0,0]"), 17.0)
+    self.assertEqual(distance_of("linf",        "[0,0]"), 4.0)
+    #self.assertEqual(distance_of("lp",          "[0,0]"), 2.0)
+    self.assertEqual(distance_of("canberra",    "[0,0]"), 2.0)
+    self.assertEqual(distance_of("braycurtis",  "[0,0]"), 1.0)
+    #self.assertEqual(distance_of("jensenshannon", "[0,0]"), 2.0)
+
+    self.assertEqual(distance_of("ip", "[2,2]"), 10.0)
+
 VECTOR_FUNCTIONS = [
   'vector0',
   'vector_debug',
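
The offset of 33 found by trial can be accounted for by Faiss's serialization layout. Assuming an `IndexIDMap2` is written as the fourcc `IxM2` followed by Faiss's generic index header (dimension, `ntotal`, two dummy int64 fields, `is_trained`, then `metric_type`), the metric type starts at byte 4 + 4 + 8 + 8 + 8 + 1 = 33. The sketch below decodes that header with Python's `struct` module. The layout is read from Faiss's `index_write.cpp` and is not a documented, stable format, and `read_index_header` is a hypothetical helper. Reading the single byte at offset 33 works on little-endian machines because the field is an int32 and every metric type value fits in one byte.

```python
import struct

# Assumed little-endian layout at the start of a serialized faiss::IndexIDMap2,
# based on Faiss's write_index_header (faiss/impl/index_write.cpp):
#   offset  0: fourcc "IxM2"        (4 bytes)
#   offset  4: d, int32             (4 bytes)
#   offset  8: ntotal, int64        (8 bytes)
#   offset 16: dummy, int64         (8 bytes)
#   offset 24: dummy, int64         (8 bytes)
#   offset 32: is_trained, bool     (1 byte)
#   offset 33: metric_type, int32   (4 bytes)
# "<" disables padding, so the fields are packed back to back.
HEADER = struct.Struct("<4siqqq?i")


def read_index_header(blob: bytes) -> dict:
    """Decode the generic index header that starts with the fourcc."""
    fourcc, d, ntotal, _dummy1, _dummy2, is_trained, metric_type = HEADER.unpack_from(blob, 0)
    header = {
        "fourcc": fourcc,
        "d": d,
        "ntotal": ntotal,
        "is_trained": is_trained,
        "metric_type": metric_type,
    }
    # Assumption: metric_arg (float32) is only written for metric types > 1,
    # i.e. for anything other than INNER_PRODUCT (0) and L2 (1).
    if metric_type > 1:
        (header["metric_arg"],) = struct.unpack_from("<f", blob, HEADER.size)
    return header


# Synthetic header for a 2-dim JensenShannon (22) index holding one vector.
blob = HEADER.pack(b"IxM2", 2, 1, 1 << 20, 1 << 20, True, 22) + struct.pack("<f", 0.0)
print(read_index_header(blob)["metric_type"])  # 22
print(blob[33])                                # 22, the byte the test inspects
```

Against a real index, `blob` would be one of the `idx` values selected from `vss_mts_index` in the test above.
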
dleviminzi commented 11 months ago

Ahh that is much more succinct than what I was going to do. Looks good, I'll add it right now.

dleviminzi commented 11 months ago

After doing some digging, I think I figured out what is going on with the JensenShannon distance. There are two things to note (I added this to the docs as well). First, the FAISS implementation of the JensenShannon distance assumes L1-normalized input, i.e. a valid probability distribution. Second, FAISS actually implements the Jensen-Shannon divergence, not the distance. I don't know if that was intentional.

Edit: I opened an issue on FAISS, and I think they will change it to compute the distance.
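
To make the divergence/distance distinction concrete, here is a small Python reference under the textbook definitions, with M = (P + Q)/2: JSD(P‖Q) = ½·KL(P‖M) + ½·KL(Q‖M), and the Jensen-Shannon distance is √JSD. It uses natural logarithms; the log base and zero handling inside FAISS are not verified here, and the helper names are hypothetical. Inputs are L1-normalized first, since the divergence is only defined for probability distributions. An all-zero vector such as the `[0,0]` query in the metric-type test cannot be normalized at all.

```python
import math
from typing import List, Sequence


def l1_normalize(v: Sequence[float]) -> List[float]:
    """Scale a non-negative vector so its entries sum to 1."""
    total = sum(v)
    if total <= 0 or any(x < 0 for x in v):
        raise ValueError("need non-negative entries with a positive sum")
    return [x / total for x in v]


def kl_divergence(a: Sequence[float], b: Sequence[float]) -> float:
    """KL(a || b) in nats; terms with a_i == 0 contribute 0 (0 * log 0 = 0).

    Requires b_i > 0 wherever a_i > 0 (always true when b is the midpoint).
    """
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def js_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence: mean KL of p and q to their midpoint m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def js_distance(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon distance: square root of the divergence (a true metric)."""
    # Clamp tiny negative values that floating-point rounding could produce.
    return math.sqrt(max(0.0, js_divergence(p, q)))


p = l1_normalize([4, 1])  # [0.8, 0.2]
q = l1_normalize([1, 1])  # [0.5, 0.5]
print(js_divergence(p, q), js_distance(p, q))
```

The divergence does not satisfy the triangle inequality, while its square root does, which is why the distance is usually preferred for nearest-neighbor search.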

asg017 commented 11 months ago

Thank you @dleviminzi ! Will be merging this and cutting a v0.1.1 release shortly.