redis / redis-py

Redis Python client
MIT License
12.68k stars 2.53k forks source link

Add scorer support for aggregations, allowing BM25 / vector hybrid search. #3408

Status: Open — opened by rbs333 1 month ago

rbs333 commented 1 month ago

Description: Currently there is no way to set the scorer in an aggregate request, which makes running hybrid BM25 / vector queries impractical. This PR adds that support so such queries can be executed.

Example test added for a hybrid query:

@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
def test_aggregations_hybrid_scoring(client):
    """Combine a BM25 text score with a KNN vector distance in FT.AGGREGATE.

    Creates an index with two weighted, sortable text fields and a 2-dim
    FLOAT32 HNSW vector field (cosine distance), stores two documents, and
    runs a hybrid aggregation that:

    * restricts candidates with a text clause on ``description`` and runs a
      KNN search over ``vector``, exposing the distance as ``dist``;
    * selects the ``BM25`` scorer via ``AggregateRequest.scorer``;
    * adds the text score to each row via ``add_scores``;
    * computes ``hybrid_score = @__score + @dist`` via ``apply``.

    This is a plain ``def`` rather than ``async def`` because every client
    call below is synchronous and nothing is awaited.
    """
    client.ft().create_index(
        (
            TextField("name", sortable=True, weight=5.0),
            TextField("description", sortable=True, weight=5.0),
            VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
            ),
        )
    )

    client.hset(
        "doc1",
        mapping={
            "name": "cat book",
            "description": "a book about cats",
            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
        },
    )
    client.hset(
        "doc2",
        mapping={
            "name": "dog book",
            "description": "a book about dogs",
            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
        },
    )

    # Text pre-filter on @description, then KNN over @vector; the vector
    # distance is aliased as `dist` so APPLY can reference it.
    query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
    req = (
        aggregations.AggregateRequest(query_string)
        .scorer("BM25")
        .add_scores()
        .apply(hybrid_score="@__score + @dist")
        .load("*")
        .dialect(4)
    )

    res = (
        client.ft()
        .aggregate(
            req,
            query_params={
                "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
            },
        )
        .rows[0]
    )

    # The row is a flat, alternating [key, value, key, value, ...] list.
    # Pair it up so the assertions do not depend on field emission order.
    row = dict(zip(res[::2], res[1::2]))

    assert len(res) == 6
    assert set(row) == {b"hybrid_score", b"__score", b"__dist"}

    # The sum is computed server-side and returned as text, so exact float
    # equality is fragile; compare with a tolerance instead.
    text_score = float(row[b"__score"])
    vector_dist = float(row[b"__dist"])
    assert float(row[b"hybrid_score"]) == pytest.approx(text_score + vector_dist)

PR: https://github.com/redis/redis-py/pull/3409