Closed gokceneraslan closed 3 years ago
Hi Gökçen, Your approach sounds right, but we haven't spent much time on this. The main effort will be around parallelizing the embedding generation, which will be specific to the cluster you're on. @ebetica may be able to share a snippet for generating a faiss index. Best, Tom
Here's a snippet:
import time
from glob import glob
from os import path
from typing import List, Tuple

import faiss
import numpy as np
import torch
from tqdm import tqdm


def load(fn: str) -> np.ndarray:
    # Assumption: each embs.*.pt file holds a 2-D (num_sequences, dim) tensor
    # saved with torch.save. FAISS expects contiguous float32 arrays.
    return np.ascontiguousarray(torch.load(fn, map_location="cpu").float().numpy())


def build_index(
    data_path: str,
    num_clusters: int,
    test: bool = False,
    rebuild: bool = False,
    pca: int = 64,
) -> Tuple[List[str], faiss.Index]:
    cache_fn = f"{data_path}/cache.faiss"
    # Sorting keeps embedding shards and sequence shards paired by suffix
    embfiles = sorted(glob(f"{data_path}/embs.*.pt"))
    seqfiles = sorted(glob(f"{data_path}/seqs.*.txt"))
    should_load = not rebuild and path.exists(cache_fn)
    if test:
        embfiles = embfiles[:2]
        seqfiles = seqfiles[:2]

    # Keep the first shard around: it gives the dimension and the sanity-check queries
    first_mat = load(embfiles[0])
    d = first_mat.shape[1]
    # Rough estimate: 4 bytes per float32, all shards assumed as large as the first
    fits_into_memory = first_mat.size * len(embfiles) * 4 < 200e9

    # PCAR64 means to do a PCA to 64 dimensions, this should get our dataset to fit into RAM
    # Middle argument (IVF..._HNSW32) is recommended for many vectors
    # Last argument (SQ8) is scalar quantization from 4 bytes to 1
    if should_load:
        print(f"Loading cached index from {cache_fn}...")
        index = faiss.read_index(cache_fn)
    elif test:
        index = faiss.index_factory(d, f"PCAR{pca},IVF32_HNSW32,SQ8")
    elif fits_into_memory:
        index = faiss.IndexFlatIP(d)
    else:
        index = faiss.index_factory(d, f"PCAR{pca},IVF{num_clusters}_HNSW32,SQ8")

    if not should_load:
        print("| Loading training set for FAISS...")
        mats = []
        total_train = 0
        # Collect roughly 40 training vectors per cluster, then stop reading shards
        with tqdm(total=num_clusters * 40) as pb:
            for fn in embfiles:
                mats.append(load(fn))
                total_train += mats[-1].shape[0]
                pb.update(mats[-1].shape[0])
                if total_train >= num_clusters * 40:
                    break

        print("| Training FAISS quantization scheme...")
        t = time.time()
        index.train(np.concatenate(mats))
        print(f"| Done in {time.time() - t} seconds")

    keys = []
    print("| Adding data to FAISS...")
    for fn, sfn in tqdm(zip(embfiles, seqfiles), total=len(embfiles)):
        # A cached index already contains the vectors; only the keys need reloading
        if not should_load:
            index.add(load(fn))
        with open(sfn, "r") as f:
            keys += [x.strip() for x in f.readlines()]

    # Sanity check: query with the first shard, so each row's nearest neighbor should be itself
    D, I = index.search(first_mat[:5], 2)
    print("Sanity check: 2-NN of first 5 elements in your data")
    print(D)
    print(I)
    print("\n".join(keys[i] for i in I[:, 0]))

    if not should_load:
        faiss.write_index(index, cache_fn)
    return keys, index
Say you dump your embeddings in files like {data_path}/embs.12345.pt and the matching sequences in {data_path}/seqs.12345.txt, one sequence per line in the same order as the rows of the embedding file (the function globs for embs.*.pt and seqs.*.txt, so the dot after the prefix matters). You can use this function to load them all and combine them into a FAISS index. Check the FAISS documentation for the right number to select for num_clusters. Pick your PCA dimension depending on how much memory you have available. I'm still experimenting with this, so sorry if the code does not work perfectly.
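To make the memory trade-off concrete, here is a rough back-of-the-envelope sketch of the arithmetic behind the 4-bytes-to-1 quantization and the PCA to 64 dimensions. The corpus size below is a made-up number for illustration only, and `index_bytes` is a hypothetical helper, not part of FAISS; the estimate counts the stored vectors only and ignores IVF/HNSW bookkeeping.

```python
def index_bytes(num_vectors: int, dim: int, bytes_per_value: int) -> int:
    """Bytes needed to store the vectors alone (no IVF/HNSW overhead)."""
    return num_vectors * dim * bytes_per_value


# Hypothetical corpus size, chosen only for illustration.
num_vectors = 250_000_000
esm1b_dim = 1280  # width of an ESM-1b embedding

# Flat index: every value kept as a 4-byte float32.
flat = index_bytes(num_vectors, esm1b_dim, 4)
# PCA to 64 dimensions, then scalar quantization to 1 byte per value.
compressed = index_bytes(num_vectors, 64, 1)

print(f"flat float32: {flat / 1e9:.0f} GB")
print(f"PCA 64 + 1-byte quantization: {compressed / 1e9:.0f} GB")
print(f"flat index fits a 200e9-byte budget: {flat < 200e9}")
```

With these assumed numbers the flat index needs about 1.3 TB while the compressed one needs about 16 GB, an 80x reduction, which is why the PCA dimension is the main knob to turn when RAM is tight.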
Thx Zeming! Let me close this now but happy to help out with any follow ups!
Thank you so much both!
Hi,
For people working in the field of protein science, it'd be useful to find the sequences/structures in UniParc that are similar to a given sequence in the ESM embedding space, as a great alternative to the existing sequence-based protein search tools and methods.
Is the easiest way to do that to 1) download UniParc, 2) compute ESM-1b embeddings, and 3) build a kNN index via Faiss or pynndescent? Or are you planning to release a script and/or a kNN/Faiss index to facilitate that somehow?
I can imagine that you have already been using a Faiss instance to do that internally, but the question is whether you'd like to release it or not 😄
Cheers.