facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License
3.29k stars 644 forks

Is there an abnormal hot line in the embedding along the sequence length dimension? #599

Open chemgeeklian opened 1 year ago

chemgeeklian commented 1 year ago

Discussed in https://github.com/facebookresearch/esm/discussions/598

Originally posted by **chemgeeklian** August 4, 2023

There are one or more abnormal hot lines in the embedding that run along the sequence length dimension: at a fixed embedding index, the values stand out at every sequence position. Because of these hot lines, training downstream models on mean-pooled embeddings can lead to pathological behavior.

**Reproduction steps**

```python
import torch
import esm
import matplotlib.pyplot as plt

# I have tested several ESM pretrained models, see figures below
modelname = 'esm1v_t33_650M_UR90S_5'
model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_5()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disable dropout for deterministic results

data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Number of non-padding tokens per sequence (not used in the plots below)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-token representations from layer 30
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[30], return_contacts=True)
esm_embed = results["representations"][30]  # (batch, tokens, embed_dim)

# One heatmap per protein: rows are sequence positions, columns are embedding dimensions
fig, axes = plt.subplots(2, 1, figsize=(20, 10), sharey=True, sharex=True)
c1 = axes[0].imshow(esm_embed[0].numpy(), cmap='hot', interpolation='nearest')
c2 = axes[1].imshow(esm_embed[1].numpy(), cmap='hot', interpolation='nearest')
fig.colorbar(c1, ax=axes, orientation='vertical')
plt.title(modelname)
plt.savefig(modelname + '.png')
plt.show()
```

Output results (the model used is named in each figure). The x axis is the embedding dimension; the y axis is the sequence length dimension.

esm2_t30_150M_UR50D: hot line around x ~ 560

![esm pretrained esm2_t30_150M_UR50D](https://github.com/facebookresearch/esm/assets/15936589/f36846b4-50ea-454a-8206-9fc1cd7e789a)

esm2_t33_650M_UR50D: hot line around x ~ 1180

![esm pretrained esm2_t33_650M_UR50D](https://github.com/facebookresearch/esm/assets/15936589/2b55b2ce-05dd-47c5-91b2-dd4129f316b7)

esm1v_t33_650M_UR90S_5: hot line around x ~ 820
![esm1v_t33_650M_UR90S_5](https://github.com/facebookresearch/esm/assets/15936589/7d3ed003-39e5-4370-b3a1-16554a7bef2a)
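
Reading the hot line position off the heatmap only gives an approximate x value. A minimal sketch of one way to get the exact embedding index is below. It assumes the embedding is a `(seq_len, embed_dim)` array and flags dimensions whose mean absolute activation is a robust outlier relative to the other dimensions. The helper name `find_hot_dimensions`, the threshold of 10, and the synthetic array (with a line planted at index 820 purely for illustration) are my own assumptions, not part of ESM and not real model output.

```python
import numpy as np


def find_hot_dimensions(embedding, z_threshold=10.0):
    """Return indices of embedding dimensions with outlying mean |activation|.

    embedding: array-like of shape (seq_len, embed_dim).
    z_threshold: hypothetical cutoff on a median/MAD-based robust z-score.
    """
    emb = np.asarray(embedding, dtype=np.float64)
    # Average magnitude of each embedding dimension over all sequence positions
    per_dim = np.abs(emb).mean(axis=0)
    # Median and MAD are used so the hot dimensions do not inflate the spread estimate
    median = np.median(per_dim)
    mad = np.median(np.abs(per_dim - median))
    robust_z = (per_dim - median) / (1.4826 * mad + 1e-12)
    return np.flatnonzero(robust_z > z_threshold)


# Synthetic stand-in for one protein's embedding (NOT real ESM output)
rng = np.random.default_rng(0)
seq_len, embed_dim = 67, 1280
fake_embed = rng.normal(0.0, 0.3, size=(seq_len, embed_dim))
fake_embed[:, 820] += 5.0  # planted hot line at a fixed embedding index

hot = find_hot_dimensions(fake_embed)
print("hot dimensions:", hot)
```

With the reproduction script above, the same function could presumably be called as `find_hot_dimensions(esm_embed[0].numpy())`; whether the threshold needs tuning for real embeddings is untested here.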