jbellis / jvector

JVector: the most advanced embedded vector search engine

RandomAccessScoreProvider with MapRandomAccessVectorValues + non-sequential IDs produces wrong centroid #354

Open vbekiaris opened 2 months ago

vbekiaris commented 2 months ago

In commit 97e523c306ae42c3e963484e320fa1c7432b5250, the approximateCentroid() implementation for the BuildScoreProvider returned by BuildScoreProvider.randomAccessScoreProvider() was updated to allow for non-sequential node IDs.

However, the iteration only takes into account nodes with ID < ravv.size(). This means that if there are "holes" in the ID sequence (e.g., add 100 nodes to a MapRandomAccessVectorValues, then remove 10 starting from ID 0), some nodes (those with ID >= 90 in the example) will not be taken into account when calculating the centroid.

A fix would probably require changing the RandomAccessVectorValues API to expose an iterator over the populated IDs, or the highest node ID that is set (or something similar).
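To make the failure mode concrete, the iteration is effectively equivalent to the sketch below (a paraphrase for illustration, not the actual JVector source; it assumes a map-backed RAVV returns null from getVector for missing IDs):

    // Illustrative paraphrase of the centroid loop (not the actual source).
    var vts = VectorizationProvider.getInstance().getVectorTypeSupport();
    VectorFloat<?> sum = vts.createFloatVector(ravv.dimension());
    for (int id = 0; id < ravv.size(); id++) {      // never reaches IDs >= size()
        VectorFloat<?> v = ravv.getVector(id);
        if (v == null) {
            continue;                               // a hole below size() is skipped...
        }
        for (int i = 0; i < sum.length(); i++) {
            sum.set(i, sum.get(i) + v.get(i));
        }
    }
    for (int i = 0; i < sum.length(); i++) {
        sum.set(i, sum.get(i) / ravv.size());       // ...but still contributes to the divisor
    }

In the test below, after the removals size() is 50, so the loop sums only the 25 surviving odd IDs below 50 (all [0, 1]) and divides by 50, yielding [0, 0.5].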

Test that demonstrates the issue:

    @Test
    void testRandomAccessScoreProvider() {
        Map<Integer, VectorFloat<?>> chm = new ConcurrentHashMap<>();
        // node IDs 0..49 have vector [0, 1], 50..99 have [1, 0], so the centroid is [0.5, 0.5]
        for (int i = 0; i < 100; i++) {
            chm.put(i, createVector(i < 50));
        }
        RandomAccessVectorValues ravv = new MapRandomAccessVectorValues(chm, 2);
        var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.COSINE);
        float[] centroid = (float[]) bsp.approximateCentroid().get();
        Assertions.assertArrayEquals(new float[] {0.5f, 0.5f}, centroid);
        // remove even node IDs: 50 node IDs remain, and they are no longer sequential.
        // The centroid should stay [0.5, 0.5], since 25 [0, 1] and 25 [1, 0] vectors remain
        for (int i = 0; i < 100; i = i+2) {
            chm.remove(i);
        }
        centroid = (float[]) bsp.approximateCentroid().get();
        // fails: the calculated centroid is [0, 0.5] because the calculation only considered node IDs < 50
        Assertions.assertArrayEquals(new float[] {0.5f, 0.5f}, centroid);
    }

    VectorFloat<?> createVector(boolean vertical) {
        VectorFloat<?> vectorFloat = VectorizationProvider.getInstance().getVectorTypeSupport().createFloatVector(2);
        if (vertical) {
            vectorFloat.set(0, 0);
            vectorFloat.set(1, 1);
        } else {
            vectorFloat.set(0, 1);
            vectorFloat.set(1, 0);
        }
        return vectorFloat;
    }
jbellis commented 2 months ago

Thanks for the report and the test!

IMO this is working as designed; the implicit contract of RAVV is that it should provide a valid vector for every ordinal in 0..size(). In other words, JVector is designed to support "holes" in a graph, but not in a RAVV.

There are apparently 31 usages of size(), so reviewing those to introduce getIdUpperBound the way we did for GraphIndex to accommodate holes isn't unreasonable, but I'm a bit skeptical that it's necessary. Can you share more about your use case?
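For concreteness, one shape the extension could take (a sketch only; the method name follows the GraphIndex precedent above, while the default body and semantics are assumptions, not an agreed design):

    // Sketch: a possible RAVV extension mirroring GraphIndex.getIdUpperBound().
    public interface RandomAccessVectorValues {
        int size();
        int dimension();
        VectorFloat<?> getVector(int nodeId);

        // One greater than the largest node ID in use. Defaulting to size()
        // leaves dense (hole-free) implementations unchanged.
        default int getIdUpperBound() {
            return size();
        }
    }

The centroid loop would then iterate up to getIdUpperBound() and divide by the count of non-null vectors it actually saw, rather than by size().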

(I'm also happy to eliminate the source of confusion by deleting MapRAVV. I thought it would be useful for Cassandra, but we ended up not using it after all.)

shultseva commented 2 months ago

Hi, I'd like to continue this discussion.

In our scenario, we have a collection where vectors can be added and removed, which makes it easy to create gaps in the RAVV.

Could you clarify how the library is intended to be used in this context? Should deletion be avoided, or should deleted node identifiers be reused?

jbellis commented 2 months ago

All of the on-heap data structures are designed around the principle that node IDs are mostly contiguous (see DenseIntMap in particular), and the on-disk structure assumes they are entirely contiguous. Additionally, removeDeletedNodes cannot be called safely while other modifications are in flight (see https://github.com/jbellis/jvector/issues/272).

I think this would work for you (a sketch follows the list):

  1. Periodically pause mutations while running removeDeletedNodes.
  2. Once that completes, you can re-use the deleted node IDs, overwriting them in your MapRAVV.
  3. Never remove entries from the MapRAVV; only overwrite them when you reuse an ID.
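A minimal sketch of that lifecycle, reusing the map-backed setup from the test above (the synchronized methods stand in for whatever pause-mutations mechanism the application already has; GraphIndexBuilder method names are from memory and exact signatures vary by JVector version):

    // Caller-side bookkeeping for compaction and ID reuse (sketch).
    class VectorStore {
        final Map<Integer, VectorFloat<?>> chm = new ConcurrentHashMap<>();
        final Deque<Integer> freeIds = new ArrayDeque<>();
        final GraphIndexBuilder builder;    // constructed elsewhere
        int nextId = 0;

        VectorStore(GraphIndexBuilder builder) { this.builder = builder; }

        synchronized void remove(int id) {
            builder.markNodeDeleted(id);    // mark in the graph; do NOT remove from chm
        }

        synchronized void compact() {
            builder.removeDeletedNodes();   // step 1: no other mutations in flight
            // step 2: the removed IDs may now be reused; today the caller must
            // track them itself (see the BitSet sketch below)
        }

        synchronized int add(VectorFloat<?> v) {
            Integer id = freeIds.poll();    // step 2: prefer a reusable ID
            if (id == null) {
                id = nextId++;              // otherwise mint a fresh one
            }
            chm.put(id, v);                 // step 3: overwrite, never remove
            builder.addGraphNode(id, v);    // signature varies by version
            return id;
        }
    }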

I would support modifying removeDeletedNodes to return a BitSet of removed IDs so user code doesn't have to track deletions a second time.
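If it did, compact() in the sketch above could replenish the free-ID pool directly (hypothetical return type; today's removeDeletedNodes returns no such BitSet):

    // Hypothetical: assumes removeDeletedNodes() were changed to return
    // the BitSet of node IDs it physically removed.
    synchronized void compact() {
        BitSet removed = builder.removeDeletedNodes();
        for (int id = removed.nextSetBit(0); id >= 0; id = removed.nextSetBit(id + 1)) {
            freeIds.add(id);                // safe to reuse once compaction completes
        }
    }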