state-spaces / mamba

Mamba SSM architecture

How to extract whole sentence embeddings #476

Open GlancerZ opened 3 months ago

GlancerZ commented 3 months ago

Compared to BERT's approach of using the CLS token to extract a whole-sentence embedding, is placing a CLS token in Mamba effective? My intuition is that a CLS token in Mamba cannot directly interact with every other token, so it might work poorly. Would extracting the last hidden state be more effective instead? Thanks!

albertfgu commented 3 months ago

Here are a couple strategies:

  1. You can try a standard unidirectional model with a CLS token appended at the end, or variants such as extracting the last hidden state (see the sketch after this list)
  2. You might want to consider bidirectional versions of Mamba instead: https://arxiv.org/abs/2407.09941. In this case, there are several strategies for extracting a sentence-level embedding:
    • you can add a CLS token anywhere in the sequence
    • you can average the final embeddings
  3. There are further strategies, like this one that repeats the sentence to emulate global context with a causal model: https://arxiv.org/abs/2402.15449
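
To make the last-hidden-state and averaging variants concrete, here is a rough sketch (not an official API). It assumes the gpt-neox-20b tokenizer that the pretrained checkpoints pair with, and that `model.backbone` returns per-token hidden states of shape (batch, seqlen, d_model), which is how `MambaLMHeadModel` is currently organized in this repo; adjust the names for your own setup.

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
).eval()

@torch.no_grad()
def embed(sentence: str, pooling: str = "last") -> torch.Tensor:
    ids = tokenizer(sentence, return_tensors="pt").input_ids.to("cuda")
    hidden = model.backbone(ids)      # per-token hidden states: (1, seqlen, d_model)
    if pooling == "last":             # strategy 1: take the final token's state
        return hidden[:, -1]
    return hidden.mean(dim=1)         # averaging variant: mean-pool over tokens

e = embed("Mamba is a selective state space model.", pooling="last")
print(e.shape)  # (1, 768) for mamba-130m, where d_model = 768
```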
yhv-wt commented 2 months ago

I've got good results by simply putting a CLS token at the end of a sequence and using only that token's embedding. I've had even better results by placing multiple CLS tokens within one sequence to extract multiple sub-sequence embeddings with some relations between the sub-sequences; in my case, each sub-sequence may depend on its own content and on something from previous sub-sequences. If you need to capture full cross-sub-sequence relations, you'd need to feed the entire sequence into the model twice: the first pass lets Mamba learn what's in the sequence, and the second pass is where you collect your results (a rough sketch follows the list below). To work with Mamba properly you need to remember the sequential nature of this model, which I consider one of its most powerful attributes:

  1. Define the task before giving the context; otherwise the model won't know what the context is for.
  2. Retrieve the results after giving the context. Mamba can't give you an answer based on information it hasn't seen yet.
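
To illustrate the above, here is a minimal sketch of the end-of-sequence CLS and feed-the-sequence-twice ideas, under the same assumptions as the sketch earlier in the thread. The CLS id is a stand-in (the gpt-neox tokenizer has no dedicated CLS token, so in practice you would reserve one and fine-tune), and the repetition flag mimics the repetition/echo idea linked above.

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
).eval()
cls_id = tokenizer.eos_token_id  # stand-in for a dedicated, trained CLS token

@torch.no_grad()
def cls_embedding(text: str, repeat: bool = False) -> torch.Tensor:
    ids = tokenizer(text, return_tensors="pt").input_ids
    if repeat:
        # Feed the sequence twice so the recurrent state has already "seen"
        # the whole input by the time it reaches the CLS position.
        ids = torch.cat([ids, ids], dim=1)
    cls = torch.tensor([[cls_id]], dtype=ids.dtype)
    ids = torch.cat([ids, cls], dim=1).to("cuda")  # CLS goes last, after all context
    hidden = model.backbone(ids)                   # (1, seqlen, d_model)
    return hidden[:, -1]                           # embedding read off at the CLS position

print(cls_embedding("Mamba processes tokens sequentially.", repeat=True).shape)
```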