rom1504 / clip-retrieval

Easily compute clip embeddings and build a clip retrieval system with them
https://rom1504.github.io/clip-retrieval/
MIT License
2.41k stars, 209 forks

Do notebook to explore assignment and centroids of index #196

Open rom1504 opened 2 years ago

rom1504 commented 2 years ago

https://github.com/facebookresearch/faiss/issues/2266 for centroids

https://github.com/rom1504/clip-retrieval/blob/main/clip_retrieval/ivf_metadata_ordering.py#L46

rom1504 commented 2 years ago

get 2 files from https://huggingface.co/datasets/laion/laion5B-index/tree/main/image.index

load index:

import faiss
index = faiss.read_index("populated.index", faiss.IO_FLAG_ONDISK_SAME_DIR)

centroids:

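# get the IVF index that sits behind the pre-transform chain
ind = faiss.extract_index_ivf(index)
# read the nlist centroids back from the coarse quantizer (transformed space)
centroids = ind.quantizer.reconstruct_n(0, ind.nlist)
# undo the first transform of the chain to return to the original embedding space
centroids = index.chain.at(0).reverse_transform(centroids)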

(need to reverse the transformation to get back to initial space)
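
To make the note above concrete, here is a minimal sketch, assuming the pre-transform is an orthonormal linear map y = A x (a rotation), so that reversing it amounts to multiplying by the transpose. The 2D rotation and the helper names are illustrative only and are not part of faiss.

```python
import math

def rotation(theta):
    """2x2 rotation matrix, a toy stand-in for an orthonormal pre-transform."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def apply(matrix, vec):
    """Forward transform: y = A x."""
    return [sum(a * x for a, x in zip(row, vec)) for row in matrix]

def reverse(matrix, vec):
    """Reverse transform for an orthonormal A: x = A^T y."""
    transposed = [list(col) for col in zip(*matrix)]
    return apply(transposed, vec)

A = rotation(0.7)
original = [1.0, 2.0]
transformed = apply(A, original)      # the space the IVF centroids are stored in
recovered = reverse(A, transformed)   # back in the original embedding space
```

Comparing centroids with raw CLIP embeddings only makes sense after this reverse step, since both then live in the same space.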

assignments, something like:

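    import numpy as np
    from tqdm import tqdm

    il = faiss.extract_index_ivf(index).invlists
    # d[old_id] = position of that item once items are ordered cluster by cluster
    d = np.ones((index.ntotal,), "int64")
    begin_list = []  # start offset of each cluster in the new ordering
    current_begin = 0
    for i in tqdm(range(il.nlist)):
        begin_list.append(current_begin)
        ids = il.get_ids(i)
        list_size = il.list_size(int(i))
        # expose the raw id pointer as a numpy array of length list_size
        items = faiss.rev_swig_ptr(ids, list_size)
        new_ids = range(current_begin, current_begin + list_size)
        d.put(np.array(items, "int"), np.array(new_ids, "int"))
        il.release_ids(ids=ids, list_no=i)
        current_begin += list_size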
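
For intuition, the loop above can be reproduced on toy data without faiss. This is a minimal sketch, assuming each inverted list is just a Python list of item ids; `toy_lists` and `build_reordering` are made up for illustration.

```python
def build_reordering(inverted_lists, ntotal):
    """Return (d, begin_list), where d[old_id] is the item's position
    once all items are laid out cluster by cluster."""
    d = [-1] * ntotal
    begin_list = []
    current_begin = 0
    for items in inverted_lists:
        begin_list.append(current_begin)
        for offset, old_id in enumerate(items):
            d[old_id] = current_begin + offset
        current_begin += len(items)
    return d, begin_list

# three clusters holding six items in total
toy_lists = [[3, 0], [4], [1, 2, 5]]
d, begin_list = build_reordering(toy_lists, ntotal=6)
```

With this layout, the items of cluster k occupy positions `begin_list[k]` up to `begin_list[k] + len(toy_lists[k]) - 1`, so the cluster of any position can be found by a binary search over `begin_list`.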

rom1504 commented 2 years ago

l = [il.list_size(i) for i in range(il.nlist)]

get all cluster list sizes
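
Once the sizes are in a list, a few summary statistics already say a lot about how balanced the clusters are. A minimal sketch, where `toy_sizes` stands in for the list `l` above and `summarize_sizes` is a hypothetical helper name:

```python
import statistics

def summarize_sizes(sizes, top_k=3):
    """Basic statistics over per-cluster sizes, plus the ids of the largest clusters."""
    order = sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True)
    return {
        "n_clusters": len(sizes),
        "n_empty": sum(1 for s in sizes if s == 0),
        "min": min(sizes),
        "max": max(sizes),
        "mean": statistics.mean(sizes),
        "median": statistics.median(sizes),
        "largest": order[:top_k],
    }

toy_sizes = [2, 0, 7, 3, 7]  # stand-in for the real list of cluster sizes
summary = summarize_sizes(toy_sizes)
```

On the real index, `summarize_sizes(l)` would be the call to make; the `largest` entry gives cluster ids that can be inspected further.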

rom1504 commented 2 years ago

get items of the largest cluster:

def id_to_items(i):
  ids = il.get_ids(i)
  list_size = il.list_size(int(i))
  items = faiss.rev_swig_ptr(ids, list_size)
  items = np.array(items, "int")  # copy the ids before releasing the pointer
  il.release_ids(ids=ids, list_no=i)
  return items

l = [il.list_size(i) for i in range(il.nlist)]
i = int(np.argmax(l))  # id of the largest cluster
ids = id_to_items(i)

rom1504 commented 2 years ago

get urls of some metadata (paste in browser console)

fetch("https://knn5.laion.ai/metadata", {
  "headers": {
    "accept": "*/*",
    "accept-language": "en,fr-FR;q=0.9,fr;q=0.8,en-US;q=0.7,zh-CN;q=0.6,zh;q=0.5,zh-TW;q=0.4",
    "content-type": "text/plain;charset=UTF-8",
    "sec-ch-ua": "\"Chromium\";v=\"106\", \"Google Chrome\";v=\"106\", \"Not;A=Brand\";v=\"99\"",
    "sec-ch-ua-mobile": "?0",
    "sec-ch-ua-platform": "\"Linux\"",
    "sec-fetch-dest": "empty",
    "sec-fetch-mode": "cors",
    "sec-fetch-site": "cross-site"
  },
  "referrer": "https://rom1504.github.io/",
  "referrerPolicy": "strict-origin-when-cross-origin",
  "body": "{\"ids\":[4832458793,4832458822],\"indice_name\":\"laion5B\"}",
  "method": "POST",
  "mode": "cors",
  "credentials": "omit"
}).then(a => a.json()).then(a => console.log(a.map(e => e.metadata)))
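
The same request can be built from Python, which is handier inside a notebook. A minimal sketch using only the standard library: the URL, body fields and content type are copied from the browser snippet above, the function names are hypothetical, and whether the service is still reachable has not been checked.

```python
import json
import urllib.request

METADATA_URL = "https://knn5.laion.ai/metadata"  # endpoint taken from the snippet above

def build_metadata_request(ids, indice_name="laion5B"):
    """Build the same POST request as the browser snippet, without sending it."""
    body = json.dumps({"ids": list(ids), "indice_name": indice_name}).encode("utf-8")
    return urllib.request.Request(
        METADATA_URL,
        data=body,
        headers={"Content-Type": "text/plain;charset=UTF-8"},
        method="POST",
    )

def fetch_metadata(ids, indice_name="laion5B", timeout=30):
    """Send the request and return the `metadata` field of each returned entry."""
    request = build_metadata_request(ids, indice_name)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        entries = json.load(response)
    return [entry["metadata"] for entry in entries]

# Only builds the request; call fetch_metadata([...]) to actually send it.
request = build_metadata_request([4832458793, 4832458822])
```

Passing the array returned by `id_to_items` (converted with `.tolist()`) to `fetch_metadata` would give the metadata of a whole cluster, assuming the service accepts that many ids in one call.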

rom1504 commented 2 years ago

=> plug that into a part of the ui

yzou2 commented 1 year ago

@rom1504, thanks for sharing the code for getting the clusters. I am trying to understand it. What does the function faiss.rev_swig_ptr(ids, list_size) do? I couldn't find good documentation for this function in faiss, so I hope you can help. Thanks!
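
As far as I can tell, `faiss.rev_swig_ptr(ptr, n)` wraps a raw C pointer returned by the SWIG bindings as a numpy array of length `n` that views the underlying memory instead of copying it. That would explain why the snippet copies the ids with `np.array(items, "int")` before calling `release_ids`. The sketch below mimics this view-versus-copy behaviour with `ctypes` only; it does not call faiss, and the buffer is a made-up stand-in for an inverted list's id array.

```python
import ctypes

# A C array of int64 ids, standing in for memory owned by the index.
n = 4
buffer = (ctypes.c_int64 * n)(10, 11, 12, 13)
address = ctypes.addressof(buffer)  # the "raw pointer"

# View: reinterpret the same memory as an array of known length (no copy).
view = (ctypes.c_int64 * n).from_address(address)

# Copy: an independent Python list that stays valid on its own.
copy = list(view)

# A write to the original memory shows up in the view but not in the copy.
buffer[0] = 99
```

Here `view[0]` follows the change to the buffer while `copy[0]` keeps the old value, which is the reason to copy before the pointer is released.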