tolgabirdal / PHDimGeneralization

Official implementation of "Intrinsic Dimension, Persistent Homology and Generalization in Neural Networks", NeurIPS 2021.

Persistent Homology Dimension for Embeddings #3

Open Amelie-Schreiber opened 11 months ago

Amelie-Schreiber commented 11 months ago

Hi, I am using part of your topology.py script to calculate the persistent homology dimension for the embeddings of a protein language model (ESM-2). The dimension estimates show low fit error but fluctuate quite a lot across repeated runs on the same protein. Would it be beneficial to run the estimator multiple times and average? What other strategies might stabilize the estimates? Below is my current script:

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load Model and Tokenizer
def load_model(model_name="facebook/esm2_t33_650M_UR50D"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

# Compute Embeddings
def get_embeddings(sequence, tokenizer, model):
    inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024, add_special_tokens=False)
    with torch.no_grad():
        outputs = model(**inputs)
    # add_special_tokens=False above means no <cls>/<eos> tokens were added,
    # so every row of the output is a residue embedding; convert to numpy
    return outputs.last_hidden_state.squeeze(0).cpu().numpy()

# Sampling Function
def sample_W(W, nSamples, isRandom=True):
    n = W.shape[0]
    if nSamples > n:
        raise ValueError(f"Requested more samples ({nSamples}) than available points ({n}).")
    random_indices = np.random.choice(n, size=nSamples, replace=False)
    return W[random_indices]

# Fractal Dimension Estimation Function (CPU-based)
def calculate_ph_dim(W, min_points=15, max_points=1024, point_jump=5, h_dim=0, print_error=False):
    from ripser import ripser

    # Ensure that max_points does not exceed the number of points in W
    max_points = min(max_points, W.shape[0])

    test_n = range(min_points, max_points, point_jump)
    lengths = []
    for n in test_n:
        # Only compute homology up to the dimension we actually use
        diagrams = ripser(sample_W(W, n), maxdim=h_dim)['dgms']

        if len(diagrams) > h_dim:
            d = diagrams[h_dim]
            d = d[d[:, 1] < np.inf]  # drop the infinite bar
            lengths.append((d[:, 1] - d[:, 0]).sum())
        else:
            lengths.append(0.0)
    lengths = np.array(lengths)

    # Closed-form least-squares fit of log(total persistence) vs. log(sample size)
    x = np.log(np.array(list(test_n)))
    y = np.log(lengths)
    N = len(x)
    m = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x ** 2).sum() - x.sum() ** 2)
    b = y.mean() - m * x.mean()

    # Mean squared residual of the log-log fit
    error = ((y - (m * x + b)) ** 2).mean()

    if print_error:
        print(f"PH dimension fit has a mean squared error of: {error}.")

    # PH dimension estimate from the slope: dim = 1 / (1 - m)
    return 1 / (1 - m)

# Main Execution
if __name__ == "__main__":
    # Example protein sequence
    protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

    # Load model and tokenizer
    tokenizer, model = load_model()

    # Compute embeddings
    embeddings = get_embeddings(protein_sequence, tokenizer, model)

    # Ensure max_points passed to calculate_ph_dim does not exceed the number of points in the embeddings
    max_points = min(1024, embeddings.shape[0])

    # Estimate the fractal dimension
    ph_dimension = calculate_ph_dim(embeddings, max_points=max_points, print_error=True)
    print(f"Estimated Persistent Homology Dimension: {ph_dimension}")

Averaging does seem to stabilize the estimates somewhat, but not as much as I would like; roughly, the repetition I have been trying looks like the sketch below. Any feedback on using your code for this purpose would be greatly appreciated!
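(Rough sketch of the repeated-run averaging, continuing from the script above and reusing embeddings and max_points; n_repeats = 10 and the choice of median vs. mean are arbitrary.)

# Repeat the estimate on the same embeddings and aggregate across runs
n_repeats = 10
estimates = np.array([
    calculate_ph_dim(embeddings, max_points=max_points) for _ in range(n_repeats)
])
print(f"Mean PH dimension:   {estimates.mean():.3f}")
print(f"Median PH dimension: {np.median(estimates):.3f}")
print(f"Std across runs:     {estimates.std():.3f}")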

tolgabirdal commented 7 months ago

@Amelie-Schreiber This is a fascinating application! Sorry for my late reply. Averaging makes sense (or robust averaging such as the median). I also tried robust linear regression, so to compute (m, b) you could leverage RANSAC, for instance: https://scikit-learn.org/stable/auto_examples/linear_model/plot_robust_fit.html. In the experiments in my paper it did not make a significant difference, but your problem is completely different, so it might help.
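A rough sketch of what swapping RANSAC in for the closed-form fit inside calculate_ph_dim could look like (assuming scikit-learn is installed; x and y are the log-transformed sample sizes and persistence sums already computed in the function, and random_state=0 is an arbitrary choice):

from sklearn.linear_model import RANSACRegressor

# x, y as in calculate_ph_dim: log sample sizes and log total persistences
ransac = RANSACRegressor(random_state=0)
ransac.fit(x.reshape(-1, 1), y)

m = ransac.estimator_.coef_[0]    # robust slope of the log-log fit
b = ransac.estimator_.intercept_  # robust intercept
ph_dim = 1 / (1 - m)              # same dimension formula as before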

Having a larger number of samples could also help. Usually a large ambient dimension combined with a small sample size leads to instability.

Keep me posted, I'm curious. :)