OML-Team / open-metric-learning

Metric learning and retrieval pipelines, models and zoo.
https://open-metric-learning.readthedocs.io/en/latest/index.html
Apache License 2.0
884 stars, 61 forks

Implement Multi-query processing #511

Open · AlekseySh opened 8 months ago

AlekseySh commented 8 months ago

The idea is that a query consists of multiple objects instead of a single one, and we want to retrieve results for all of these objects at once. A straightforward approach is frequency voting:

As a result, we should have an example similar to "Using a trained model for retrieval" (https://github.com/OML-Team/open-metric-learning?tab=readme-ov-file#examples).
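Frequency voting can be sketched as follows. This assumes that each query in a group contributes its own top-k list, that every gallery id gets one vote per list it appears in, that ids are ranked by vote count, and that ties are broken by the smallest distance seen for the id. `frequency_voting` is a hypothetical helper for illustration, not an OML API.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def frequency_voting(
    retrieved_per_query: Sequence[Sequence[int]],
    distances_per_query: Sequence[Sequence[float]],
) -> List[Tuple[int, int, float]]:
    """Merge the top-k lists of several queries that describe the same object.

    Returns (gallery_id, votes, best_distance) tuples sorted by votes
    (descending), then by best distance (ascending) as a tie-breaker.
    """
    votes: Counter = Counter()
    best_dist: Dict[int, float] = {}
    for ids, dists in zip(retrieved_per_query, distances_per_query):
        for gid, d in zip(ids, dists):
            votes[gid] += 1  # one vote per query list containing this id
            best_dist[gid] = min(d, best_dist.get(gid, float("inf")))
    ranking = sorted(votes, key=lambda gid: (-votes[gid], best_dist[gid]))
    return [(gid, votes[gid], best_dist[gid]) for gid in ranking]


# Two queries of the same object (ids and distances match the draft below).
result = frequency_voting(
    retrieved_per_query=[[0, 1, 2, 3], [4, 2]],
    distances_per_query=[[0.1, 0.3, 0.6, 0.9], [0.5, 0.8]],
)
print(result)  # [(2, 2, 0.6), (0, 1, 0.1), (1, 1, 0.3), (4, 1, 0.5), (3, 1, 0.9)]
```

Id 2 is retrieved by both queries, so it moves to the top even though its distances are larger than those of id 0. This is the main behavioural difference from averaging distances.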

VSXV commented 8 months ago

Hope I understood correctly. Here is my variant.

AlekseySh commented 8 months ago

responded in DM :)

AlekseySh commented 7 months ago

waiting for https://github.com/OML-Team/open-metric-learning/issues/522

AlekseySh commented 7 months ago

EXAMPLE

[Image: multi query]

AlekseySh commented 6 months ago

DRAFT

from collections import defaultdict

import numpy as np
import torch
from torch import FloatTensor, LongTensor

from oml.retrieval import RetrievalResults

rr = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.6, 0.9]),
        FloatTensor([0.5, 0.8]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 2, 3]),
        LongTensor([4, 2]),
        LongTensor([10, 20]),
        LongTensor([])
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        LongTensor([0, 2, 50]),  # todo: gt_ids may be inconsistent across queries of the same group
        LongTensor([10, 30]),
        LongTensor([50])
    ]
)

query_groups = [[0, 1], [2], [3]]

rr_expected = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([10, 20]),
        LongTensor([])
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        LongTensor([0, 2, 50]),
        LongTensor([10, 30]),
        LongTensor([50])
    ]
)

# For every group, merge the retrieved items of its queries into one ranking.
# An item retrieved by several queries of the group gets the mean of its distances.
distances_upd, retrieved_ids_upd = dict(), dict()
for group in query_groups:
    group_lens = [len(rr.retrieved_ids[ig]) for ig in group]
    if set(group_lens) == {0}:
        # nothing was retrieved for any query of the group
        for ig in group:
            distances_upd[ig] = FloatTensor([])
            retrieved_ids_upd[ig] = LongTensor([])

    else:
        dist_group = torch.concat([rr.distances[ig] for ig in group])
        ri_group = torch.concat([rr.retrieved_ids[ig] for ig in group])

        # collect all distances observed for each retrieved id
        ri2dist = defaultdict(list)
        for d, ri in zip(dist_group, ri_group):
            ri2dist[int(ri)].append(float(d))

        # average them and sort by the averaged distance (ascending)
        ri_dist = [(ri, float(np.mean(d))) for ri, d in ri2dist.items()]
        ri_dist = sorted(ri_dist, key=lambda x: x[1])
        ri_upd, dist_upd = zip(*ri_dist)

        # all queries of the group share the merged ranking
        for ig in group:
            distances_upd[ig] = FloatTensor(dist_upd)
            retrieved_ids_upd[ig] = LongTensor(ri_upd)

# restore the original query order (every query must belong to some group)
distances_upd_final = []
retrieved_ids_upd_final = []
for iq in range(len(rr.retrieved_ids)):
    distances_upd_final.append(distances_upd[iq])
    retrieved_ids_upd_final.append(retrieved_ids_upd[iq])

rr_produced = RetrievalResults(distances=distances_upd_final, retrieved_ids=retrieved_ids_upd_final, gt_ids=rr.gt_ids)

print(rr_expected)
print(rr_produced)
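The group-merging logic of the draft can also be packaged as a standalone function over plain lists of tensors, which makes it testable without constructing a `RetrievalResults`. This is a sketch under assumptions: `merge_query_groups` is a hypothetical helper, not part of OML; it reuses the draft's mean-distance rule, and queries that belong to no group keep their original results (the draft would raise a `KeyError` for them).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

import torch
from torch import Tensor


def merge_query_groups(
    distances: Sequence[Tensor],
    retrieved_ids: Sequence[Tensor],
    query_groups: Sequence[Sequence[int]],
) -> Tuple[List[Tensor], List[Tensor]]:
    """Give every query of a group the same merged ranking.

    A gallery id retrieved by several queries of the group gets the mean of
    its distances; the merged list is sorted by that mean (ascending).
    Queries not listed in any group keep their original results.
    """
    dist_out: List[Tensor] = list(distances)
    ids_out: List[Tensor] = list(retrieved_ids)
    for group in query_groups:
        # collect every distance observed for each gallery id in the group
        id2dists: Dict[int, List[float]] = defaultdict(list)
        for iq in group:
            for d, gid in zip(distances[iq].tolist(), retrieved_ids[iq].tolist()):
                id2dists[gid].append(d)
        # mean distance per id, sorted ascending; an empty group yields []
        merged = sorted(
            ((gid, sum(ds) / len(ds)) for gid, ds in id2dists.items()),
            key=lambda x: x[1],
        )
        merged_ids = torch.tensor([gid for gid, _ in merged], dtype=torch.long)
        merged_dists = torch.tensor([d for _, d in merged], dtype=torch.float)
        for iq in group:
            ids_out[iq] = merged_ids
            dist_out[iq] = merged_dists
    return dist_out, ids_out


distances = [
    torch.tensor([0.1, 0.3, 0.6, 0.9]),
    torch.tensor([0.5, 0.8]),
    torch.tensor([0.1, 0.2]),
    torch.tensor([]),
]
retrieved_ids = [
    torch.tensor([0, 1, 2, 3]),
    torch.tensor([4, 2]),
    torch.tensor([10, 20]),
    torch.tensor([], dtype=torch.long),
]
dists_upd, ids_upd = merge_query_groups(distances, retrieved_ids, [[0, 1], [2], [3]])
print(ids_upd[0].tolist())  # [0, 1, 4, 2, 3]
```

Because an empty group simply produces empty tensors, the special case for groups with no retrieved items is not needed here.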