lucidrains / enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
MIT License

Distribution of embedding values #18

Closed sofroniewn closed 1 year ago

sofroniewn commented 1 year ago

I'm using the embeddings generated by Enformer for another project, and while I'm getting meaningful results from my own linear output heads, I was struck by the distribution of the embedding values themselves.

If I run the following

import matplotlib.pyplot as plt
import numpy as np
import torch
from enformer_pytorch import Enformer

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
seq = torch.randint(0, 5, (2, 196_608)) # for ACGTN, in that order (-1 for padding)
output, embeddings = enformer(seq, return_embeddings=True)  # embeddings: (2, 896, 3072)

plt.hist(embeddings.detach().numpy().ravel(), bins=np.linspace(-0.5, 0.5, 200), density=True);
plt.xlim([-0.5, 0.5]);

I see the following plot

[image: histogram of the embedding values]

I can confirm I get a very similar distribution of values using real sequences instead of random noise.

I'm struck by the sharp cutoff at around -0.16, the U-shaped distribution for negative values, the strong peak at zero, and the tail of positive values. Naively I would have expected to see something more normally distributed around zero, but I guess I don't have a lot of experience looking at outputs of transformers. I'm curious whether this distribution is to be expected based on the model, or if it is at least of no concern to experts. I'm also curious whether any DeepMind-provided embedding values have the same distribution. I haven't downloaded any, but could try looking into that if it were of broader interest.

As noted above, I do get decent predictions from my own linear output heads, but I'm struggling to match the distribution of my outputs to the distribution of the target values, so I was curious whether this could be one contributing factor, though there could be many other factors at play that are specific to my application.

lucidrains commented 1 year ago

@sofroniewn Hi Nicholas

Yes, you are correct - typically in transformers we add a layernorm at the very end of all the attention / feedforward blocks, just before the linear head projection. Embeddings would be taken after this layernorm, and would probably have more of the distribution you are used to seeing

However that layernorm seems to be absent in the Enformer architecture

You can always add an additional transformer block followed by this layernorm and learn it yourself
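A minimal sketch of that suggestion, keeping the pretrained trunk frozen and using a stock nn.TransformerEncoderLayer as the extra block (the block type, head count, and number of tracks below are illustrative assumptions, not part of enformer-pytorch):

import torch
from torch import nn

# a learned transformer block followed by a final LayerNorm, applied on top of the
# frozen Enformer embeddings before a task-specific linear head
class PostNormHead(nn.Module):
    def __init__(self, dim = 3072, num_tracks = 1):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(
            d_model = dim,
            nhead = 8,
            dim_feedforward = dim * 2,
            batch_first = True,
            norm_first = True
        )
        self.norm = nn.LayerNorm(dim)          # the final layernorm usually placed just before the head
        self.to_tracks = nn.Linear(dim, num_tracks)

    def forward(self, embeddings):             # embeddings: (batch, 896, 3072) from return_embeddings=True
        x = self.block(embeddings)
        x = self.norm(x)
        return self.to_tracks(x)

head = PostNormHead()
pred = head(torch.randn(2, 896, 3072))         # stand-in for the Enformer embeddings above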

lucidrains commented 1 year ago

@sofroniewn are you using one of the fine tuning modules? i could add this as an option, provided you share whether this helps or not for what you are working on

lucidrains commented 1 year ago

ohh i see, Enformer does have a final normalization, but in the form of a batchnorm within the final pointwise convolutional block (https://github.com/lucidrains/enformer-pytorch/blob/main/enformer_pytorch/modeling_enformer.py#L334-L339). the current heads are projected from after a GELU activation in that final pointwise block, which explains what you are seeing
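The GELU nonlinearity accounts for the sharp negative cutoff: gelu(x) = x * Φ(x) has a global minimum of roughly -0.17, attained near x ≈ -0.75, so values read out directly after a GELU can never go much below that floor, which matches the ~-0.16 edge in the histogram. A quick check:

import torch
import torch.nn.functional as F

# GELU has a global minimum of about -0.17 (attained near x = -0.75),
# so anything taken right after a GELU is bounded below by roughly that value
x = torch.linspace(-10, 10, 100_001)
y = F.gelu(x)
print(y.min().item())        # approximately -0.17
print(x[y.argmin()].item())  # approximately -0.75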

lucidrains commented 1 year ago

@sofroniewn could you try setting this to True for your task in v0.6.0? and also i'd be curious what the resulting distribution looks like after fine tuning
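For reference, a guess at what enabling this could look like, assuming it refers to the post_transformer_embed flag mentioned further down and that the flag is exposed on the fine-tuning head adapter in enformer_pytorch.finetune (the exact keyword and placement should be checked against the v0.6.0 source):

import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')

# assumption: post_transformer_embed is a keyword argument of the head adapter wrapper
model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 1,
    post_transformer_embed = True
)

seq = torch.randint(0, 5, (1, 196_608))
target = torch.randn(1, 896, 1)

loss = model(seq, target = target)   # the wrapper returns a loss when a target is supplied
loss.backward()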

sofroniewn commented 1 year ago

Hi Phil

Thanks for the detailed response and code option. So far I've actually just been using the embeddings generated from the pre-trained network and feeding them into my own custom head, without any fine-tuning of the Enformer. Right now I'm also just taking a single embedding vector at the midpoint of the sequence. My current task is analogous to the CAGE predictions from the TSS in the Enformer paper.

I just tried adding my own layernorm before my custom head and performance went down very slightly, still using the fixed embeddings as inputs. I will investigate this further before drawing any strong conclusions, though, as I have a couple of different tasks I can try this on.
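Roughly, the setup described above looks like the following sketch: the single center-bin embedding vector, an optional layernorm, and a linear head on top of precomputed embeddings (the number of targets is an illustrative assumption):

import torch
from torch import nn

# a center-bin embedding vector (dim 3072), an optional LayerNorm, and a linear head
class CenterBinHead(nn.Module):
    def __init__(self, dim = 3072, num_targets = 1, use_layernorm = True):
        super().__init__()
        self.norm = nn.LayerNorm(dim) if use_layernorm else nn.Identity()
        self.to_pred = nn.Linear(dim, num_targets)

    def forward(self, embeddings):                        # embeddings: (batch, 896, 3072)
        center = embeddings[:, embeddings.shape[1] // 2]  # one vector per sequence, at the center bin
        return self.to_pred(self.norm(center))            # (batch, num_targets)

head = CenterBinHead(use_layernorm = True)
pred = head(torch.randn(2, 896, 3072))                    # stand-in for the precomputed embeddings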

I have been interested in trying to fine-tune the Enformer but have shied away from it so far due to the computational demands and the fact that my task right now just uses a single embedding vector from the whole sequence. I might revisit this soon though and start trying to fine-tune - in which case I will try both with and without the post_transformer_embed option and report back.

Thanks again for your help here!