Open AlekseySh opened 8 months ago
I hope I understood the task correctly. Here is my variant.
responded in DM :)
EXAMPLE
DRAFT
from collections import defaultdict
import numpy as np
import torch
from torch import FloatTensor, LongTensor
from oml.retrieval import RetrievalResults
# Input fixture: per-query retrieval results for four queries.
# Entry i of each list belongs to query i; the last query retrieved nothing.
rr = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.6, 0.9]),
        FloatTensor([0.5, 0.8]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 2, 3]),
        LongTensor([4, 2]),
        LongTensor([10, 20]),
        LongTensor([]),
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        # TODO: ground truths of queries that share a group may be inconsistent.
        LongTensor([0, 2, 50]),
        LongTensor([10, 30]),
        LongTensor([50]),
    ],
)

# Each inner list holds the indices of the queries that form one group:
# queries 0 and 1 belong together, queries 2 and 3 are on their own.
query_groups = [[0, 1], [2], [3]]
# Expected outcome of merging `rr` by `query_groups`.
# Queries 0 and 1 share one merged list: item 2 was retrieved by both of them
# (distances 0.6 and 0.8), so it appears once with the mean distance 0.7.
# Queries 2 and 3 are single-query groups and keep their original results.
rr_expected = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([10, 20]),
        LongTensor([]),
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        LongTensor([0, 2, 50]),
        LongTensor([10, 30]),
        LongTensor([50]),
    ],
)
# Merge the retrieval results inside every query group: all queries of a
# group end up sharing one ranked list built from the union of their items.
# Both dicts are keyed by query index.
distances_upd, retrieved_ids_upd = {}, {}
for group in query_groups:
    if not any(len(rr.retrieved_ids[ig]) for ig in group):
        # No query of the group retrieved anything, so there is nothing to
        # merge; `zip(*ri_dist)` below cannot be unpacked from an empty list.
        # An empty group also lands here and is simply skipped.
        for ig in group:
            distances_upd[ig] = FloatTensor([])
            retrieved_ids_upd[ig] = LongTensor([])
        continue

    dist_group = torch.concat([rr.distances[ig] for ig in group])
    ri_group = torch.concat([rr.retrieved_ids[ig] for ig in group])

    # Collect every distance observed for each retrieved item across the group.
    ri2dist = defaultdict(list)
    for d, ri in zip(dist_group, ri_group):
        ri2dist[int(ri)].append(float(d))

    # An item retrieved by several queries gets the mean of its distances;
    # the merged list is then ranked by ascending distance (stable sort).
    ri_dist = sorted(
        ((ri, float(np.mean(d))) for ri, d in ri2dist.items()),
        key=lambda x: x[1],
    )
    ri_upd, dist_upd = zip(*ri_dist)

    # Every query gets its own tensors so the results do not share storage.
    for ig in group:
        distances_upd[ig] = FloatTensor(dist_upd)
        retrieved_ids_upd[ig] = LongTensor(ri_upd)

# Restore the original query order. A query index that is missing from
# `query_groups` raises KeyError here.
distances_upd_final = [distances_upd[iq] for iq in range(len(rr.retrieved_ids))]
retrieved_ids_upd_final = [retrieved_ids_upd[iq] for iq in range(len(rr.retrieved_ids))]

# Ground truths are passed through unchanged.
rr_produced = RetrievalResults(
    distances=distances_upd_final,
    retrieved_ids=retrieved_ids_upd_final,
    gt_ids=rr.gt_ids,
)

print(rr_expected)
print(rr_produced)
The concept is that a query involves multiple objects instead of just one. We aim to retrieve results for all these objects simultaneously. A straightforward approach is to use frequency voting:
As a result, we should have an example similar to "Using a trained model for retrieval" (https://github.com/OML-Team/open-metric-learning?tab=readme-ov-file#examples).