facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

[Question] Similarity search using the embeddings of the training dataset (uniparc) #23

Closed gokceneraslan closed 3 years ago

gokceneraslan commented 3 years ago

Hi,

For people working in protein science, it'd be useful to find the sequences/structures in UniParc that are similar to a given sequence in the esm embedding space, as an alternative to existing sequence-based protein search tools.

Is the easiest way to do that to 1) download UniParc, 2) get esm-1b embeddings, and 3) build a kNN index via Faiss or pynndescent, or are you planning to release a script and/or a kNN/Faiss index to facilitate that somehow?

I can imagine that you have already been using a Faiss instance to do that internally, but the question is whether you'd like to release it or not 😄

Cheers.
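
For concreteness, step 3 with pynndescent might look roughly like this (a sketch only; the embedding file and parameters are illustrative):

import numpy as np
from pynndescent import NNDescent

# Hypothetical matrix of per-sequence esm-1b embeddings, one row per uniparc entry.
embeddings = np.load("uniparc_esm1b_embeddings.npy").astype(np.float32)

index = NNDescent(embeddings, metric="cosine", n_neighbors=30)
index.prepare()

# 10 nearest uniparc entries to the first sequence's embedding.
neighbor_ids, distances = index.query(embeddings[:1], k=10)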

tomsercu commented 3 years ago

Hi Gökçen,
Your approach sounds right, but we haven't spent much time on this. The main effort will be around parallelizing the embedding generation, which will be specific to the cluster you're on. @ebetica may be able to share a snippet for generating a faiss index.
Best, Tom
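
For reference, here is a minimal sketch of generating per-sequence ESM-1b embeddings with the esm package and dumping them in the shard layout the snippet below expects (the mean-pooling, file names, and directory are illustrative assumptions, not necessarily how we generate embeddings internally):

import torch
import esm

# Load ESM-1b and its alphabet/batch converter (documented esm API).
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()

# Toy batch; in practice you would shard uniparc and parallelize this loop.
data = [
    ("seq1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("seq2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]
_, _, batch_tokens = batch_converter(data)

with torch.no_grad():
    out = model(batch_tokens, repr_layers=[33])  # layer 33 = final layer of ESM-1b
reps = out["representations"][33]

# One vector per sequence: mean over residue positions (skip the BOS token at position 0).
embs = torch.stack([reps[i, 1 : len(seq) + 1].mean(0) for i, (_, seq) in enumerate(data)])

# Dump in the layout assumed below: a float32 matrix per shard plus a text file
# with one sequence id per line (shard index 0 here is illustrative).
torch.save(embs.float(), "embeddings/embs.0.pt")
with open("embeddings/seqs.0.txt", "w") as f:
    f.write("\n".join(label for label, _ in data) + "\n")

Note that ESM-1b accepts sequences up to ~1022 residues, so longer uniparc entries need truncation or windowing.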

ebetica commented 3 years ago

Here's a snippet:

import time
from glob import glob
from os import path
from typing import List, Tuple

import faiss
import numpy as np
import torch
from tqdm import tqdm


def load(fn: str) -> np.ndarray:
    # Assumption: each embs.*.pt shard holds a 2D matrix of embeddings (one row per
    # sequence) saved with torch.save; FAISS wants contiguous float32 arrays.
    x = torch.load(fn)
    x = x.numpy() if isinstance(x, torch.Tensor) else np.asarray(x)
    return np.ascontiguousarray(x, dtype=np.float32)


def build_index(
    data_path: str,
    num_clusters: int,
    test: bool = False,
    rebuild: bool = False,
    pca: int = 64,
) -> Tuple[List[str], faiss.Index]:
    cache_fn = f"{data_path}/cache.faiss"
    embfiles = list(sorted(glob(f"{data_path}/embs.*.pt")))
    seqfiles = list(sorted(glob(f"{data_path}/seqs.*.txt")))
    should_load = not rebuild and path.exists(cache_fn)

    if test:
        embfiles = embfiles[:2]
        seqfiles = seqfiles[:2]

    mat = load(embfiles[0])
    d = mat.shape[1]
    # Rough size estimate: assume every shard is about as large as the first (float32 = 4 bytes)
    fits_into_memory = mat.size * len(embfiles) * 4 < 200e9
    # Index factory string used below:
    #   PCAR{pca}: PCA (with a random rotation) down to `pca` dimensions, so the dataset fits into RAM
    #   IVF{num_clusters}_HNSW32: inverted-file coarse quantizer with HNSW assignment, recommended for many vectors
    #   SQ8: scalar quantization from 4 bytes per dimension down to 1
    if should_load:
        print(f"Loading cached index from {cache_fn}...")
        index = faiss.read_index(cache_fn)
    elif test:
        index = faiss.index_factory(d, f"PCAR{pca},IVF32_HNSW32,SQ8")
    elif fits_into_memory:
        index = faiss.IndexFlatIP(d)
    else:
        index = faiss.index_factory(d, f"PCAR{pca},IVF{num_clusters}_HNSW32,SQ8")

    if not should_load:
        print("| Loading training set for FAISS...")
        mats = []
        total_train = 0
        # Collect roughly 40 training vectors per centroid for quantizer training
        with tqdm(total=num_clusters * 40) as pb:
            for fn in embfiles:
                mats.append(load(fn))
                total_train += mats[-1].shape[0]
                pb.update(mats[-1].shape[0])
                if total_train >= num_clusters * 40:
                    break
        print("| Training FAISS quantization scheme...")
        t = time.time()
        index.train(np.concatenate(mats))
        print(f"| Done in {time.time() - t} seconds")

    keys = []

    print("| Adding data to FAISS...")
    for fn, sfn in tqdm(zip(embfiles, seqfiles), total=len(embfiles)):
        if not should_load:
            mat = load(fn)
            index.add(mat)

        with open(sfn, "r") as f:
            keys += [x.strip() for x in f.readlines()]

    D, I = index.search(mat[:5], 2)  # sanity check
    print("Sanity check: 2-NN of first 5 elements in your data")
    print(D)
    print(I)
    print("\n".join(keys[i] for i in I[:, 0]))

    if not should_load:
        faiss.write_index(index, cache_fn)

    return keys, index

Say you dump your embeddings in {data_path}/embs.12345.pt and the corresponding sequence ids in {data_path}/seqs.12345.txt (one id per line, in the same order as the embedding rows). You can use this function to load them all and combine them into a FAISS index. Check the FAISS documentation for the right number to select for num_clusters, and pick your PCA dimension depending on how much memory you have available. I'm still experimenting with this, so sorry if the code does not work perfectly.
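
For illustration, a hypothetical call might look like this (the path and num_clusters are placeholders; the FAISS guidelines suggest on the order of 4*sqrt(N) to 16*sqrt(N) clusters for N vectors):

keys, index = build_index("/path/to/embeddings", num_clusters=4096)

# Query with one (or more) embeddings of the same dimensionality.
query = load("/path/to/embeddings/embs.12345.pt")[:1]
D, I = index.search(query, 10)   # distances and ids of the 10 nearest neighbors
print([keys[j] for j in I[0]])   # map FAISS ids back to sequence identifiers

Note that the in-memory path builds an inner-product index (IndexFlatIP), so normalize the embeddings first if you want cosine similarity; for the IVF variants you can trade speed for recall by raising faiss.extract_index_ivf(index).nprobe.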

tomsercu commented 3 years ago

Thx Zeming! Let me close this now, but happy to help out with any follow-ups!

gokceneraslan commented 3 years ago

Thank you so much both!