nleroy917 opened this issue 1 year ago
hey! curious what kind of data you are using it for
would actually recommend https://github.com/lucidrains/perceiver-pytorch#experimental , which does what you want with ISAB, with all the tiny details that made perceiver work
Oh nice I didn't know that. Thank you! I will promptly check this out.
Specifically, we are interested in genomic region and gene transcription data. So, a "bag of genes" or a "bag of transcription factor binding sites". One can think of it as a sentence, but there's no inherent order - so positional encodings aren't necessary, nor do they make any sense in this context.
In the past, with other language models, we've fudged this by randomly shuffling the "bag of genes" to simulate context, but it would be nice to not have to do that.
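(For illustration, that shuffling workaround is just something like the following; the gene IDs are made up.)

import random

bag_of_genes = ["ENSG0001", "ENSG0002", "ENSG0003", "ENSG0004"]  # made-up gene IDs
random.shuffle(bag_of_genes)  # a fresh order each pass, so a position-aware LM can't latch onto order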
@nleroy917 that is indeed a valid use case! let me know how your experiments with the ISAB-perceiver fare and maybe i'll put a bit more polish into it
@lucidrains small update! (We can move this discussion over to the Perceiver repo if necessary, since this is slightly outside the ISAB scope.)
I read the Perceiver paper and agree it would be a nice fit here. I had a question, however. I noticed the experimental Perceiver implementation requires specifying the number of input channels and frequency bands. That doesn't really apply to transcription or region data, since it truly is just a list of genes or regions; it can be thought of directly as a sentence with no order. So there is really only one input channel and only one frequency_band. It's just ["gene1", "gene2", ...] (we literally represent the data as strings).
I was looking at PerceiverLM and was wondering if I could just remove the positional embeddings entirely:
from torch import nn
from perceiver_pytorch import PerceiverIO

class PerceiverLM(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        # self.pos_emb = nn.Embedding(max_seq_len, dim)
        # removed since the data is permutation invariant

        self.perceiver_io = PerceiverIO(
            dim = dim,
            queries_dim = dim,
            logits_dim = num_tokens,
            **kwargs
        )

    def forward(
        self,
        x,
        mask = None
    ):
        x = self.token_emb(x)

        # n, device = x.shape[1], x.device
        # pos_emb = self.pos_emb(torch.arange(n, device = device))
        # pos_emb = rearrange(pos_emb, 'n d -> () n d')
        # x = x + pos_emb
        # removed since the data is permutation invariant

        logits = self.perceiver_io(x, mask = mask, queries = x)
        return logits
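For reference, roughly how I'd expect to call the modified version for our gene-ID setting; the sizes and the extra PerceiverIO kwargs below are just placeholder guesses on my part.

import torch

model = PerceiverLM(
    num_tokens = 25000,   # ~25k gene vocabulary
    dim = 512,
    max_seq_len = 2048,   # unused now that positions are removed, kept only for the signature
    depth = 6,            # forwarded to PerceiverIO
    num_latents = 256,
    latent_dim = 512
)

gene_token_ids = torch.randint(0, 25000, (1, 2048))
mask = torch.ones(1, 2048).bool()

logits = model(gene_token_ids, mask = mask)  # (1, 2048, 25000)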
I wasn't able to fully work out the exact mechanism implemented by Perceiver - I am still mostly working off my knowledge from "Attention Is All You Need". The training task is going to be BERT-like.
Let me know if you have any suggestions. Thank you again for the insight!
@nleroy917 yup, you could, but each nucleotide within each gene will still need positions?
@nleroy917 yea, actually the more i think about it, it is still better to use regular language modeling (or perhaps the new hyena dna if you need to capture long distance interactions) for the nucleotides or k-mers within each gene. then you can apply perceiver, or a regular transformer followed by pooling for some invariant prediction, on the set of genes
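schematically, the second stage i mean is just the following (made-up sizes, and it assumes you already have a per-gene embedding from the first stage)

import torch
from torch import nn

gene_embs = torch.randn(8, 512, 256)  # (batch, num genes, dim) from the per-gene model
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model = 256, nhead = 8, batch_first = True),
    num_layers = 4
)
# no positional encodings are added, so the encoder is permutation equivariant;
# mean pooling over the gene set then makes the prediction order invariant
pooled = encoder(gene_embs).mean(dim = 1)
logits = nn.Linear(256, 10)(pooled)   # arbitrary downstream head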
@nleroy917 are you doing supervised or unsupervised learning? if you are doing unsupervised, i'm not sure if there is anything to be pulled off the shelf
Hi @lucidrains
The task would be unsupervised; masked language modeling to be specific with various downstream fine-tuning tasks. So, BERT.
To be more clear, I am interested in replication of the results of this paper: https://www.nature.com/articles/s41586-023-06139-9.epdf?sharing_token=u_5LUGVkd3A8zR-f73lU59RgN0jAjWel9jnR3ZoTv0N2UB4yyXENUK50s6uqjXH69sDxh4Z3J4plYCKlVME-W2WSuRiS96vx6t5ex2-krVDS46JkoVvAvJyWtYXIyj74pDWn_DutZq1oAlDaxfvBpUfSKDdBPJ8SKlTId8uT47M%3D
They take a "bag of genes" approach but use the standard transformers library from Hugging Face (which doesn't use any experimental/cutting-edge attention as far as I am aware), and I would be interested in assessing the performance of a more appropriate architecture, like a set transformer.
So, literally, the "sentences" are lists of gene IDs: ["ENS0001", "ENS0002", ..., "ENS00100"], which get tokenized into specific genes (no byte pair encoding, etc.; it's akin to splitting on whitespace). Vocab size of ~25,000. Hope that makes sense.
TL;DR I'm interested in a genomic approach where the modality isn't raw sequence information, but rather features extracted from a genomic experiment that have no inherent order.
Pseudo-code:
gene_tokenizer = GeneTokenizer()
data = read_scRNA_dataset("/path/to/matrix.h5ad")
model = SetTransformer(...)

for cell in data:
    gene_tokens = gene_tokenizer(cell)
    # ["ENS0001", "ENS0002", ... ]
    gene_token_ids, mask = gene_tokenizer.encode_tokens(gene_tokens)
    # gene_token_ids = [102, 1034, 4562, ....]
    # mask = [1, 1, 1, ... ]
    out = model(gene_token_ids, mask=mask)
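And the BERT-style objective I have in mind is roughly the following; the mask id and the 15% rate are just the usual BERT defaults (80/10/10 replacement trick omitted), and the tiny model is only a stand-in so the snippet runs on its own.

import torch
import torch.nn.functional as F

vocab_size, mask_token_id = 25_000, 4                   # hypothetical [MASK] id
gene_token_ids = torch.randint(5, vocab_size, (1, 2048))

# stand-in for the set model above, just to keep this self-contained
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, 64),
    torch.nn.Linear(64, vocab_size)
)

# mask ~15% of gene positions and only score the model on those
mlm_mask = torch.rand(gene_token_ids.shape) < 0.15
inputs = gene_token_ids.clone()
inputs[mlm_mask] = mask_token_id

logits = model(inputs)                                  # (1, 2048, vocab_size)
loss = F.cross_entropy(logits[mlm_mask], gene_token_ids[mlm_mask])
loss.backward()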
ah i see, i actually think self-attention will be hard to beat, even if it does not account for permutation invariance
a lot of the exotic types of attention don't work out. in fact, i've tried ISAB architecture for a contracting project and it underperformed regular self-attention. however, if you are trying to innovate for longer context, i think that is worth exploring
what is the average number of tokens per gene, and how many genes typically?
Wow, that's interesting. I appreciate that insight and anecdote. We've used Word2Vec in the past, and it can learn great associations even though the data is permutation invariant.
what is the average number of tokens per gene, and how many genes typically?
I've yet to read the paper in its entirety, but I think it's one token mapped to one gene, and they reported ~25,000 genes in the paper. I'm curious how you think self-attention might scale?
I know that a big area right now is scaling up context, but assuming a context of ~2048, how might self-attention respond to a vocab of 250,000? Or a million? Does the memory complexity just explode or does the model struggle to learn?
@nleroy917 memory is no longer an issue with recent advances like flash attention; however, 25k is definitely nearing the upper limit for sequence length (you're still paying the quadratic compute)
without reading the paper, is geneformer already doing self-attention on 25k tokens? i don't think that is possible with naive self attention. for 250k to 1 million, then i think you are onto something and agree on trying ISAB / perceiver with modifications
i guess we should both read the geneformer paper first haha
also saw this some time ago which may be relevant
Sorry if that was confusing.
To be specific: Geneformer used a context size of 2048 (max_seq_length), while their number of unique tokens was 25,000.
I was inquiring about scaling up the vocab to ~1 million, i.e. your input embeddings are a matrix of dimensionality 1 million rows by ~100. That paper looks interesting and I'll check it out. I appreciate all the insight; it's super helpful and is giving me a clearer picture here.
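Back-of-the-envelope (my own arithmetic, so take it with a grain of salt), the embedding table alone at that scale doesn't look scary:

# parameter memory for a 1M x 100 embedding table in fp32
rows, dim, bytes_per_param = 1_000_000, 100, 4
print(f"{rows * dim * bytes_per_param / 1e9:.1f} GB")  # ~0.4 GB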
@nleroy917 if you are referring to the vocabulary size (number of unique genes), 1 million is fine
if you are using vocab in place of the input (which is gene_token_ids in your example above, being passed into model), then you are right that ISAB would be the best approach. flash attention can go up to about 16-32k
ok i think i understand, you'd like to basically shoot for longer context than geneformer, which is capped at 2048
if you are working with gene ids, your input exceeds 16k unique genes, and no nucleotide sequences are being passed in, then agreed, the experimental PerceiverLM should be tried. i think you'll also have to think about how to approach set prediction. maybe the DETR paper has some clues
in which case, yea, you are right that removing the positional information should be fine. i misunderstood and thought you had the nucleotide sequences for each gene basically flattened as the input
Yeah, there's actually no sequence information. It's even simpler than that... just a list of genes. The point we've always come back to is that a "bag of genes" has no inherent order. You can shuffle it and it's still the same; no information is lost.
@nleroy917 ok, let's start with being able to turn off the fourier encodings
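for reference, the non-experimental Perceiver already exposes a flag for this; assuming the experimental class ends up with the same argument name (sizes below are arbitrary), it would look something like

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 256,         # dim of whatever per-gene features get fed in
    input_axis = 1,               # a flat set, not an image grid
    num_freq_bands = 6,           # irrelevant once the encodings are off
    max_freq = 10.,
    depth = 6,
    num_latents = 256,
    latent_dim = 512,
    num_classes = 2,              # arbitrary head, just for the example
    fourier_encode_data = False   # no positional fourier features for a bag of genes
)

genes = torch.randn(1, 2048, 256) # (batch, num genes, feature dim)
out = model(genes)                # (1, 2)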
Omg.
Thank you for that. So, to recap, you still think self-attention is superior?
it is if you are at 8 to 16k
realistically, you won't scale up to longer contexts without some struggle. you'll likely need some curriculum learning, no matter which arch you choose. an ISAB / perceiver MLM may also have stability issues, given you are off the beaten path
I think perhaps first you should test your hypothesis that more context would help by substituting the self-attention in geneformer with flash attention and fine-tuning to 4096 or 8192. If you see a signal, then shoot for 25k with perceiver and see how many latents are required to maintain expressiveness
here is another paper that used something like ISAB for generative modeling, for reference: https://github.com/lucidrains/recurrent-interface-network-pytorch. I can count on one hand the number of papers that use this, so you are breaking new ground
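for the flash attention swap, the core of it is just routing the attention math through torch's fused kernel; a rough sketch of the shape of the change (not geneformer's actual module layout)

import torch
import torch.nn.functional as F

# toy shapes: batch 2, 8 heads, 8192 gene tokens, 64-dim heads
q = torch.randn(2, 8, 8192, 64)
k, v = torch.randn_like(q), torch.randn_like(q)

# on GPU this dispatches to a flash / memory-efficient kernel when available,
# so the full 8192 x 8192 attention matrix is never materialized
out = F.scaled_dot_product_attention(q, k, v)  # (2, 8, 8192, 64)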
That makes a ton of sense. I'll start here. This discussion cleared up lots of areas of confusion and gave me lots of starting points, so thank you. Appreciate all the work you do! I'll try to provide updates as they come (I work slowly since much of this is outside my area of expertise 🥲)
yea no problem! keep me updated
Hi!
Thank you so much for implementing this. My lab has had the set transformer on the radar for a bit since we have some permutation invariant data we are working with... Do you have an example of how one might utilize this with x-transformers? My initial idea was something like:
This is just based on reading Vaswani2017, Lee2019, and poking around the code. I see that Encoder is just inheriting AttentionLayers, so I wasn't sure if it could be a direct swap. This is way outside my area of expertise, so I thought I would ask here first before chasing myself down a rabbit hole.
Thank you!
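P.S. For concreteness, my rough mental model of the ISAB block from Lee2019, written from scratch - so it may be off, and it's not meant to reflect this repo's actual implementation:

import torch
from torch import nn

class NaiveISAB(nn.Module):
    # H = Attn(I, X): m learned inducing points summarize the full set
    # Y = Attn(X, H): the set attends back to those m summaries
    # cost is O(n * m) instead of O(n^2); the residual feedforward is omitted for brevity
    def __init__(self, dim, heads, num_inducing_points):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(num_inducing_points, dim))
        self.attn_summarize = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.attn_broadcast = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.norm_h = nn.LayerNorm(dim)
        self.norm_y = nn.LayerNorm(dim)

    def forward(self, x, key_padding_mask = None):
        i = self.inducing.unsqueeze(0).expand(x.shape[0], -1, -1)
        h, _ = self.attn_summarize(i, x, x, key_padding_mask = key_padding_mask)
        h = self.norm_h(h + i)
        y, _ = self.attn_broadcast(x, h, h)
        return self.norm_y(y + x)

genes = torch.randn(2, 4096, 256)                          # a "bag" of 4096 gene embeddings
out = NaiveISAB(256, 8, num_inducing_points = 64)(genes)   # (2, 4096, 256)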