OpenBioML / protein-lm-scaling


Add curriculum learning strategy #39

Open pascalnotin opened 10 months ago

pascalnotin commented 10 months ago

Adapt sequence sampling throughout training so that we first train on "simpler" sequences and then on increasingly "complex" sequences. We discussed three different types of strategies, depending on how complexity is defined:

  1. Sequence length: shorter sequences first
  2. Perplexity: lower perplexity first (e.g., based on (pseudo-)perplexity from a pretrained model such as ESM2, RITA, Tranception, or ProGen)
  3. Structural similarity to PDB: lower pLDDT first (e.g., with AF2 or ESMFold -- the latter might be easier to scale computationally and would handle all sequences, even those with no homolog)

talkhanz commented 10 months ago

Sounds cool. I can give it a shot!

talkhanz commented 10 months ago

/take

talkhanz commented 10 months ago

I assume we are going to create a custom DataLoader for this?

Leo-T-Zang commented 10 months ago

I am also interested in helping with this problem. Shall we work together, @talkhanz?

Leo-T-Zang commented 10 months ago

/take

pascalnotin commented 9 months ago

Thank you @talkhanz and @Leo-T-Zang ! Updated the issue with various possible curriculum schemes we discussed today. Will let you decide which scheme(s) you prefer working on :)

talkhanz commented 9 months ago

> I am also interested in helping with this problem. Shall we work together, @talkhanz?

Hey @Leo-T-Zang, sounds good. I can work on the sequence-length scheme, and maybe you can pick one of the other two. We should probably agree on an approach to avoid merge conflicts. Perhaps a main CurriculumLearning class with strategy as an attribute to decide which scheme is applied under the hood?
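Something like this is the rough shape I have in mind (names and details are placeholders, not a final API):

from typing import Callable, Dict, List

class CurriculumLearning:
    """Placeholder sketch: orders training sequences by a chosen difficulty metric."""

    def __init__(self, strategy: str = "sequence_length"):
        # Map each strategy name to a function scoring one sequence
        # (lower score = "easier", so it is seen earlier in training).
        self.strategies: Dict[str, Callable[[str], float]] = {
            "sequence_length": lambda seq: float(len(seq)),
            # "perplexity" and "plddt" schemes would plug in here later.
        }
        if strategy not in self.strategies:
            raise ValueError(f"Unknown strategy: {strategy}")
        self.strategy = strategy

    def score(self, sequence: str) -> float:
        return self.strategies[self.strategy](sequence)

    def order(self, sequences: List[str]) -> List[str]:
        # Easiest-first ordering used to build the curriculum.
        return sorted(sequences, key=self.score)

# cl = CurriculumLearning(strategy="sequence_length")
# cl.order(["MKVLAAGG", "MK", "MKV"])  # -> ["MK", "MKV", "MKVLAAGG"]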

Leo-T-Zang commented 9 months ago

Sure, can we talk over Discord? What is your username?

talkhanz commented 9 months ago

> Sure, can we talk over Discord? What is your username?

Same as my GitHub.

talkhanz commented 9 months ago

@pascalnotin just to update: we've finalized our approach, and a PR (for strategy=sequence_length) should emerge quickly once I figure out a good unit test for this.

Muedi commented 9 months ago

Hi,

to follow up on my suggestion in the talk:

We could use some entropy-based (or generally information-theoretic) metric to assess a given sequence. I asked my colleagues who work with this stuff, so let's see if they have a cool idea or want to join our effort here.

I quickly skimmed Google, and the Shannon entropy seems to use either a probability matrix of AAs over a large sequence set or an MSA. I, however, had something simpler in mind: compare the expected counts of AAs for a given sequence with the actual counts and compute a score from that, perhaps? That would already move away from actual information theory, of course, but it could be a simple way to try things out :D

Another idea that was briefly discussed is to employ an algorithm that searches for repeats of any kind and use this information to build a curriculum.

However, a general question comes to mind (which could be obvious to some of you, as I have not worked much with proteins): would it actually make sense to train on a low-complexity protein (e.g. with many repeats) first? These are normally the hardest to work with in modelling, as far as I know, right? So could it make sense to do it the other way around: more complex sequences first and less complex last?

I'll be looking forward to the discussion :)

@pascalnotin should I include this in the discord?

pascalnotin commented 9 months ago

Hi Max -- looking forward to hearing more about the sequence-only entropy-based metric. I think we want to stay away from anything based on MSAs for this. Looking for motifs could be interesting -- perhaps something based on: 1) counting the frequent k-mers in the full corpus (we can rely on the fast procedures used in BPE for that), and 2) counting the instances of these common k-mers in each sequence. Feel free to post ideas / questions in the Discord channel as well!
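Roughly what I have in mind, as a sketch only (k and the cutoff are arbitrary, and a BPE-style implementation would be much faster than this):

from collections import Counter
from typing import List, Set

def frequent_kmers(corpus: List[str], k: int = 3, top: int = 1000) -> Set[str]:
    # Step 1: count all k-mers over the full corpus and keep the most frequent ones.
    counts = Counter()
    for seq in corpus:
        counts.update(seq[i:i + k] for i in range(len(seq) - k + 1))
    return {kmer for kmer, _ in counts.most_common(top)}

def common_kmer_fraction(seq: str, common: Set[str], k: int = 3) -> float:
    # Step 2: fraction of a sequence's k-mers that are "common" (higher = simpler).
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    return sum(kmer in common for kmer in kmers) / len(kmers) if kmers else 0.0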

Muedi commented 9 months ago

The k-mers idea sounds good!

To go in the Shannon entropy direction: couldn't we just count the AAs in each sequence and divide by length, basically getting a general probability of an AA being anywhere in the given sequence? With this we could then compute the entropy over all AAs to score each sequence on its own.
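Something like this minimal sketch is what I mean:

import numpy as np
from collections import Counter

def sequence_entropy(seq: str) -> float:
    # Shannon entropy (in bits) of the amino-acid composition of a single sequence.
    counts = Counter(seq)
    probs = np.array(list(counts.values()), dtype=float) / len(seq)
    return float(-(probs * np.log2(probs)).sum())

# Repeat-rich, low-complexity sequences score low:
# sequence_entropy("QQQQQQQQQQ")            -> 0.0
# sequence_entropy("ACDEFGHIKLMNPQRSTVWY")  -> ~4.32 (maximum for 20 AAs)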

Muedi commented 9 months ago

My colleagues from the IDR side suggested that, since the complexity of a sequence varies quite a bit along the sequence, a sliding window would make sense.

This package: https://pappulab.github.io/localCIDER/ has a function that computes a complexity metric for a sequence using the sliding-window approach :) This package also has things like getting info on charged residues in the sequence etc., which could be useful for this issue, but also for the clustering perhaps. Thoughts?
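Just to make the sliding-window idea concrete without committing to localCIDER yet, a sketch that reuses the sequence_entropy function from my previous comment could look like this:

def windowed_entropy(seq: str, window: int = 20, step: int = 1) -> list:
    # Entropy profile along the sequence; summarise with e.g. min() or mean() for the curriculum.
    if len(seq) <= window:
        return [sequence_entropy(seq)]
    return [sequence_entropy(seq[i:i + window])
            for i in range(0, len(seq) - window + 1, step)]

# score = min(windowed_entropy(seq))  # complexity of the most repetitive local stretch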

Amelie-Schreiber commented 9 months ago

I have a proposal based on the "intrinsic dimension" of the embeddings of a pre-trained protein language model. This would require computing embeddings of proteins using a pLM, and then computing the intrinsic dimension of those embeddings as a measure of complexity. As a reference, this was used to detect AI-generated text (and might be used similarly for proteins). A good reference is this paper: https://arxiv.org/abs/2306.04723

Amelie-Schreiber commented 9 months ago

This can be computed as follows:

import numpy as np
from sklearn.linear_model import LinearRegression
from transformers import AutoTokenizer, EsmModel
import torch
from scipy.sparse.csgraph import minimum_spanning_tree

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Input text
text = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

# Tokenize the input and convert to tensors
inputs = tokenizer(text, return_tensors='pt')

# Get the embeddings
with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[0].numpy()

# Remove the first and last embeddings (<CLS> and <EOS>)
embeddings = embeddings[1:-1]

# Sizes for the subsets to sample
sizes = np.linspace(2, len(embeddings), num=100, dtype=int)

# Prepare data for linear regression
x = []
y = []

for size in sizes:
    # Sample a subset of the embeddings
    subset = np.random.choice(len(embeddings), size, replace=False)
    subset_embeddings = embeddings[subset]

    # Compute the distance matrix
    dist_matrix = np.sqrt(np.sum((subset_embeddings[:, None] - subset_embeddings)**2, axis=-1))

    # Compute the minimum spanning tree
    mst = minimum_spanning_tree(dist_matrix).toarray()

    # Calculate the persistent score E (the maximum edge length in the MST)
    E = np.max(mst)

    # Append to the data for linear regression
    x.append(np.log(size))
    y.append(np.log(E))

# Reshape for sklearn
X = np.array(x).reshape(-1, 1)
Y = np.array(y).reshape(-1, 1)

# Linear regression
reg = LinearRegression().fit(X, Y)

# Estimated Persistent Homology Dimension
phd = 1 / (1 - reg.coef_[0][0])
print(phd)

This is an attempt at implementing the estimation of Persistent Homology Dimension described in Intrinsic Dimension Estimation for Robust Detection of AI-Generated Texts. This is a fractal dimension that could be used to detect AI-generated protein sequences, but it could also be considered a type of complexity measure for each protein sequence.

  1. Model and Tokenizer Initialization: An ESM-2 protein model and tokenizer are initialized. These represent the mapping function $ f: M \rightarrow \mathbb{R}^n $.

  2. Protein Sequence Input and Tokenization: A protein sequence, which can be thought of as a finite subset $ X \subseteq M $, is provided as input and tokenized. The tokens are then converted to tensors, which are the required input format for the ESM model.

  3. Embedding Generation: The token tensors are passed through the model $ f $ to generate a set of embeddings $ Y = f(X) $. Each embedding is a point in $ \mathbb{R}^n $.

  4. Embedding Subsetting: The embeddings corresponding to the first and last tokens are removed. These special tokens do not represent actual amino acids in the protein sequence and so are not part of the subset $ X $.

  5. Subsampling, Distance Matrix Calculation, and MST Calculation: A series of subsets $ S_i \subseteq Y $, $i = 1, \ldots, k $ are created, with sizes $ n_i $ varying from 2 to $ |Y| $. For each subset $ S_i $, a distance matrix is calculated and used to compute the minimum spanning tree (MST). The MST corresponds to a graph $ G \subseteq Y $.

  6. Persistent Score Calculation: The persistent score $E_0^\alpha(S_i) $ is calculated as the maximum edge length in the MST, which corresponds to the lifespans of 0-dimensional features in the PH computation. In this case, $ \alpha = 1 $.

  7. Data Preparation: The sizes of the subsets and the corresponding persistent scores are logged (natural logarithm) and stored for linear regression.

  8. Linear Regression: Linear regression is performed on the log-transformed sizes and persistent scores to approximate the relationship $ \log E_0^\alpha(S_i) \sim (1 - \frac{\alpha}{d}) \log n_i $. The slope of the regression line is then used to calculate the Persistent Homology Dimension (PHD) using the formula $ d = \frac{1}{1 - \text{reg.coef}[0][0]} $, which corresponds to $ d = \text{dim}_0^{PH}(M) $.

In summary, the code is estimating the 0-dimensional Persistent Homology Dimension of the manifold $M$ represented by the protein sequence. This is achieved by generating embeddings for the sequence, calculating the MST for subsets of these embeddings, and then performing linear regression on the sizes of these subsets and their corresponding persistent scores.

Muedi commented 9 months ago

Hi Amelie,

Thanks for the input! And directly with implementation even! :D

This will also likely be helpful for my aforementioned colleagues :)

But generally, isn't it a bit much to run inference on our complete corpus just to enable better training? I don't have a good understanding of how much compute that would need, but it seems a bit like shooting sparrows with cannons.

Amelie-Schreiber commented 8 months ago

Yes, you're probably right. There was talk of using ESM-2 as a way of scoring complexity of proteins, so it seemed relevant. But it could be a bit overkill to do this for every protein in the training data. I guess I don't have a good understanding of the compute necessary to do it this way either.

Muedi commented 8 months ago

Hm... perhaps @NZ99 has a better idea than us about the compute necessary here?

Also @talkhanz @Leo-T-Zang, do you think your data collator updates could be reused with Amelie's code and some implementation of the single-sequence entropy idea? What would be needed to integrate the two?

talkhanz commented 8 months ago

Great work, guys. The integration of any new CL metric is straightforward: we need to use this function https://github.com/OpenBioML/protein-lm-scaling/blob/0e48cfff6cbd1f3b11b95fbda34db39ad3ee6557/protein_lm/modeling/getters/dataset.py#L51 to set the metric column of a dataset according to our logic. The input_column_name refers to the column containing our sequences, while curriculum_column_name refers to the new column hosting our metric values. Effectively, we are precomputing a new column to be appended to the dataset.

The rest is handled by the Trainer.

You also need to set the arguments appropriately at https://github.com/OpenBioML/protein-lm-scaling/blob/0e48cfff6cbd1f3b11b95fbda34db39ad3ee6557/protein_lm/configs/train/toy_localcsv.yaml#L11-L13 or https://github.com/OpenBioML/protein-lm-scaling/blob/0e48cfff6cbd1f3b11b95fbda34db39ad3ee6557/protein_lm/configs/train/toy_hf.yaml#L11-L13
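For a new metric such as entropy or intrinsic dimension, I imagine the precompute step would look roughly like this (a sketch only; the column names here are made up, the real ones come from the config):

from datasets import Dataset

def add_curriculum_column(dataset: Dataset, input_column_name: str,
                          curriculum_column_name: str, metric_fn) -> Dataset:
    # Precompute the curriculum metric for every sequence and append it as a new column.
    return dataset.map(
        lambda row: {curriculum_column_name: metric_fn(row[input_column_name])}
    )

# e.g. with a trivial metric (sequence length); entropy / PHD functions would plug in the same way:
# ds = Dataset.from_dict({"sequence": ["MKV", "MKVLAAGGQQQQ"]})
# ds = add_curriculum_column(ds, "sequence", "sequence_length", len)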

Leo-T-Zang commented 8 months ago

@Muedi, yes, I think we can simply add them as @talkhanz described. The current CL strategy also supports continuous values like perplexity and pLDDT, so it should also work for entropy and intrinsic dimension. Given that all these continuous values are compute-intensive, I would suggest precomputing them and storing them with the dataset so we can use them later. For people interested, I just updated my PR #58. Talha and I will further make sure it is working correctly.

Amelie-Schreiber commented 8 months ago

If there is a publication that anyone is working on, I would be happy to contribute a write-up of the intrinsic dimension and its use for determining the complexity of the proteins. Just let me know and I will start working on this.

Muedi commented 8 months ago

So if I wanted to implement these, would the best course of action be to write functions in a utility script and then connect with the dataset people?

pascalnotin commented 8 months ago

Hi @Muedi @Amelie-Schreiber -- thanks for the great suggestions above! To contribute these ideas to the codebase, you would need to create a PR similar to the one that @Leo-T-Zang just created for ppl and plddt, building on the original PR from @talkhanz. Besides the changes to the codebase (dataset.py in particular) and the test routine, a separate script would be needed to pre-compute the complexity metrics you defined above (with its own test routine). Let me know if that makes sense!

Amelie-Schreiber commented 8 months ago

For posterity, and in case someone else gets to this before me, I think we should make the function for estimating the intrinsic dimension using PHD as follows:

import numpy as np
from sklearn.linear_model import LinearRegression
from transformers import AutoTokenizer, EsmModel
import torch
from scipy.sparse.csgraph import minimum_spanning_tree

# Load the tokenizer and model
model_path = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = EsmModel.from_pretrained(model_path)

def estimate_persistent_homology_dimension_avg(sequence, num_subsets, num_iterations):
    """
    Estimate the persistent homology dimension of a given protein sequence.

    Parameters:
    - sequence: A string representing the protein sequence.
    - num_subsets: A positive integer indicating the number of subsets of the embedding vectors to use. Max of 2**n where n=len(sequence). 
    - num_iterations: A positive integer indicating the number of iterations for averaging.

    Returns:
    - avg_phd: Average estimated persistent homology dimension.
    """

    phd_values = []  # List to store PHD values for each iteration

    # Tokenize the input and compute the embeddings once; they are deterministic,
    # so only the random subset sampling below needs to be repeated per iteration
    inputs = tokenizer(sequence, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[0].numpy()

    # Remove the first and last embeddings (<CLS> and <EOS>)
    embeddings = embeddings[1:-1]

    # Sizes for the subsets to sample
    sizes = np.linspace(2, len(embeddings), num=num_subsets, dtype=int)

    for _ in range(num_iterations):

        # Prepare data for linear regression
        x = []
        y = []

        for size in sizes:
            # Sample a subset of the embeddings
            subset = np.random.choice(len(embeddings), size, replace=False)
            subset_embeddings = embeddings[subset]

            # Compute the distance matrix
            dist_matrix = np.sqrt(np.sum((subset_embeddings[:, None] - subset_embeddings)**2, axis=-1))

            # Compute the minimum spanning tree
            mst = minimum_spanning_tree(dist_matrix).toarray()

            # Calculate the persistent score E (the maximum edge length in the MST)
            E = np.max(mst)

            # Append to the data for linear regression
            x.append(np.log(size))
            y.append(np.log(E))

        # Reshape for sklearn
        X = np.array(x).reshape(-1, 1)
        Y = np.array(y).reshape(-1, 1)

        # Linear regression
        reg = LinearRegression().fit(X, Y)

        # Estimated Persistent Homology Dimension for this iteration
        phd = 1 / (1 - reg.coef_[0][0])

        phd_values.append(phd)

    avg_phd = np.mean(phd_values)  # Average over all iterations
    return avg_phd

# Example usage:
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
estimated_dimension_avg = estimate_persistent_homology_dimension_avg(protein_sequence, 10, 20)
print(estimated_dimension_avg)

I'm not sure there is any benefit to using larger models than esm2_t6_8M_UR50D, since we are only interested in the relative complexity of the proteins (training on lower-intrinsic-dimension proteins first). So, the smallest ESM-2 model should suffice. This is true for the other proposed measures of complexity too, though, no? Also, while a higher number of subsets and iterations may provide better estimates, it will also be a little slower. How much slower, I'm not sure.

Muedi commented 8 months ago

Hi Amelie,

Why do you propose to use subsets instead of taking the complete sequence here? It kind of makes sense to me, as within a sequence there can be very strong variance in complexity, but then, why do we average at the end? :)

I plan to get some coding time for this either tomorrow or on Friday; if you want, I can notify you so that we do not do double work.

Amelie-Schreiber commented 8 months ago

I'm just following the recipe in the paper. I'm going to have to dig a little deeper into the references in the paper to really understand why this estimate is an estimate of intrinsic dimension; the concept actually feels quite difficult to me, depending on the day I'm looking at it. I know the averaging is a way of stabilizing the estimation of the dimension, which has high variance, as you mentioned. It's a kind of fractal dimension, and different subsets will have different dimensions that go from local to global, if I am understanding the paper correctly. The example they give in the paper is of a spiral-galaxy-type point cloud where the small neighborhoods of the arms are closer to 1-dimensional, while larger neighborhoods closer to the center become 2-dimensional, making the dimension somewhere between 1 and 2. The sizes of the subsets are supposed to vary and be drawn from a uniform distribution. The idea is to have a growing subset size to estimate the dimension at different resolutions. They give a more detailed algorithm in the appendix that is more specific than this, with the hyperparameters they found best in their experiments. I don't know if these hyperparameters are best for proteins, though. I haven't really had the chance to test the intrinsic dimensions of a lot of proteins yet; it's definitely on my list of things to do.

Amelie-Schreiber commented 8 months ago

Just thought I would post this info here. This might be a better implementation than mine. It seems like there are various methods to estimate intrinsic dimension and that there are connections to generalizability. Also, I found this reference: paper reference. It has an implementation of the GPU computation of persistent homology dimension, which can be found here; there is also a non-GPU implementation. I've also written a more faithful version following the algorithm in the Appendix of the first paper I provided; I'll post it in a bit.

Amelie-Schreiber commented 8 months ago

Just in case it is needed, here is the version that replicates the algorithm in the Appendix of Intrinsic Dimension Estimation for Robust Detection of AI-Generated Texts. Please check this in case I made any silly errors:

import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.spatial import distance_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from transformers import AutoTokenizer, AutoModel
import torch

def get_embeddings(text, model_name="facebook/esm2_t6_8M_UR50D"):
    """
    Compute embeddings for each token in the text using a specified model.

    Parameters:
    - text (str): The input text for which embeddings need to be computed.
    - model_name (str): The path to the pretrained model.

    Returns:
    - numpy.ndarray: A matrix where each row is the embedding of a token in the text.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
    with torch.no_grad():
        outputs = model(**inputs)

    # Return embeddings after removing <cls> and <eos> tokens and converting to numpy.
    return outputs.last_hidden_state[:, 1:-1, :].squeeze(0).numpy()

def compute_persistent_score(embeddings):
    """
    Compute the persistent score for a subset of embeddings using the sum of edge weights in the MST.

    Parameters:
    - embeddings (numpy.ndarray): A matrix where each row is an embedding.

    Returns:
    - float: The persistent score for the embeddings.
    """
    dist_matrix = distance_matrix(embeddings, embeddings)
    mst = minimum_spanning_tree(dist_matrix)
    return mst.sum()

def sample_and_score(embeddings, n, k=8, hat_n=40, J=7):
    """
    For various sample sizes, compute the median persistent score across J samples.

    Parameters:
    - embeddings (numpy.ndarray): A matrix where each row is an embedding.
    - n (int): Total number of embeddings.
    - k (int): Number of different sample sizes.
    - hat_n (int): A parameter for determining sample sizes.
    - J (int): Number of samples for each sample size.

    Returns:
    - list: List of sample sizes.
    - list: List of corresponding median persistent scores.
    """
    scores = []
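    # Note: this assumes n >= hat_n (i.e. sequences of at least ~40 tokens); for shorter
    # sequences, hat_n would need to be reduced so the sampled sizes do not exceed n.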
    sizes = [(i - 1) * (n - hat_n) // k + hat_n for i in range(1, k + 1)]

    for size in sizes:
        subset_scores = [compute_persistent_score(embeddings[np.random.choice(n, size, replace=False)])
                         for _ in range(J)]
        scores.append(np.median(subset_scores))

    return sizes, scores

def estimate_dimension(sizes, scores):
    """
    Estimate the intrinsic dimension of the data using linear regression on log-transformed sizes and scores.

    Parameters:
    - sizes (list): List of sample sizes.
    - scores (list): List of corresponding median persistent scores.

    Returns:
    - float: Estimated dimension of the data.
    """
    log_sizes = np.log(sizes).reshape(-1, 1)
    log_scores = np.log(scores)

    reg = LinearRegression().fit(log_sizes, log_scores)
    slope = reg.coef_[0]

    return 1 / (1 - slope)

def estimate_text_dimension(text, runs=3):
    """
    Estimate the intrinsic dimension of the text by repeatedly sampling subsets of its tokens, 
    computing their persistent scores, and then using linear regression on the log-transformed values.

    Parameters:
    - text (str): The input text for which the dimension needs to be estimated.
    - runs (int): Number of runs with different random seeds.

    Returns:
    - float: Estimated dimension of the text.
    """
    embeddings = get_embeddings(text)
    n = embeddings.shape[0]

    slopes = []
    for _ in range(runs):
        sizes, scores = sample_and_score(embeddings, n)
        log_sizes = np.log(sizes).reshape(-1, 1)
        log_scores = np.log(scores)

        reg = LinearRegression().fit(log_sizes, log_scores)
        slopes.append(reg.coef_[0])

    kappa_F = np.mean(slopes)
    return 1 / (1 - kappa_F)

text = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
dimension = estimate_text_dimension(text)
print(f"Estimated dimension of the text: {dimension}")

We could also use TwoNN, which is implemented here: https://scikit-dimension.readthedocs.io/en/latest/ The TwoNN method seems to be comparable to the persistent homology dimension method above, but it isn't as robust to noise. By the way, I ran some small-scale tests on a few hundred natural proteins from UniProt and compared them to some sequences generated by ESM-IF1 for a protein backbone that I generated using RFDiffusion, and it seems like the natural proteins do have a higher intrinsic dimension than the AI-generated ones. So, this could probably be used to detect AI-generated proteins, and it might be used to improve models if we can find a way to encourage models to generate proteins with higher intrinsic dimension.
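If we wanted to try it, a quick sketch with scikit-dimension (reusing the get_embeddings function from the script above; I haven't benchmarked this against the PHD version) might look like:

import skdim

def twonn_dimension(sequence: str) -> float:
    # Estimate the intrinsic dimension of a sequence's token embeddings with TwoNN.
    embeddings = get_embeddings(sequence)  # from the script above
    return float(skdim.id.TwoNN().fit(embeddings).dimension_)

# twonn_dimension(text)  # should be roughly comparable to the PHD estimate above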