KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License
5.93k stars, 658 forks

faiss search #491

Open wangyijun91 opened 2 years ago

wangyijun91 commented 2 years ago

pytorch_metric_learning.utils.inference

In the function try_gpu (line 258), both the gpu_index and the cpu_index branches call add_to_index_and_search, but that function always moves the inputs to the CPU before calling the faiss index:

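```python
def add_to_index_and_search(index, query, reference, k):
    if reference is not None:
        # The reference embeddings are always copied to the CPU before being added.
        index.add(reference.float().cpu())
    # The query embeddings are likewise moved to the CPU before the search.
    return index.search(query.float().cpu(), k)
```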

KevinMusgrave commented 2 years ago

When the faiss index is spread across multiple GPUs, the inputs have to be on the CPU (see https://github.com/facebookresearch/faiss/issues/1997). Faiss then moves the tensors back to the GPUs internally.

When the index is on a single GPU, the input tensors can also be on the GPU.

Maybe I can add some code that moves the tensors to the CPU only when the index is on multiple GPUs.
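A minimal sketch of that idea, under stated assumptions. The `on_single_gpu` parameter is hypothetical and not part of the library's API; the caller is assumed to know where it placed the index. Because try_gpu also sends its cpu_index through this function, and a CPU index cannot read GPU memory, the sketch keeps tensors on their device only for a single-GPU index and moves them to the CPU in every other case. It also assumes, as the quoted function already does by passing torch tensors, that faiss's torch bridge (`faiss.contrib.torch_utils`) is active so `index.add` and `index.search` accept tensors.

```python
def add_to_index_and_search(index, query, reference, k, on_single_gpu=False):
    """Add `reference` to `index` (if given), then search it with `query`.

    Hypothetical variant: `on_single_gpu` is NOT a library parameter.
    Set it to True only when the faiss index lives on exactly one GPU.
    """

    def prepare(x):
        x = x.float()
        # Assumption: a single-GPU index can take the tensor where it already is.
        # A CPU index, or an index spread across several GPUs, gets a CPU copy.
        return x if on_single_gpu else x.cpu()

    if reference is not None:
        index.add(prepare(reference))
    return index.search(prepare(query), k)
```

In try_gpu, the gpu_index branch could pass `on_single_gpu=True` when only one GPU is used, while the multi-GPU and cpu_index paths keep the default and behave exactly as the current function does.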