state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

[Need more info] `causal_conv1d` doesn't simulate "shifting of tokens by 1" #57

Open: sudhakarsingh27 opened this issue 10 months ago

sudhakarsingh27 commented 10 months ago

If the local conv (or "causal conv1d") is intended to shift the tokens by 1, then shouldn't this be `padding=d_conv` instead of `padding=d_conv - 1`? (Or can the convolution kernel learn to always ignore the rightmost element?)

https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L73
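A minimal pure-Python sketch of the padding question, assuming a PyTorch-style `nn.Conv1d` (cross-correlation, no kernel flip) whose output is truncated to the first L positions, which is how we read the linked module. `causal_conv1d_ref` and `input_indices` are hypothetical helpers, not part of `mamba_ssm` or the `causal_conv1d` package. With `pad = d_conv - 1`, output t reads inputs t - d_conv + 1 through t; with `pad = d_conv`, it reads t - d_conv through t - 1, i.e. a shift by one. The last check speaks to the parenthetical question: a kernel whose rightmost tap is zero reproduces the shifted layout.

```python
def causal_conv1d_ref(x, w, pad):
    """Left-pad x with `pad` zeros, slide kernel w over it the way
    torch.nn.Conv1d does (cross-correlation, no flip), and keep only the
    first len(x) outputs. Hypothetical reference helper."""
    k = len(w)
    assert pad >= k - 1, "need at least k - 1 zeros of left padding"
    xp = [0.0] * pad + list(x)
    return [sum(w[j] * xp[t + j] for j in range(k)) for t in range(len(x))]


def input_indices(t, k, pad):
    """Positions of the original (unpadded) input read by output t."""
    return [t + j - pad for j in range(k) if t + j - pad >= 0]


d_conv = 4
print(input_indices(5, d_conv, pad=d_conv - 1))  # [2, 3, 4, 5]: includes x[5]
print(input_indices(5, d_conv, pad=d_conv))      # [1, 2, 3, 4]: shifted by one

# A kernel whose rightmost tap is zero, used with pad = d_conv - 1,
# gives the same outputs as the "shift by one" layout with pad = d_conv.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
a, b, c = 0.5, -1.0, 2.0
print(causal_conv1d_ref(x, [a, b, c, 0.0], pad=3) ==
      causal_conv1d_ref(x, [0.0, a, b, c], pad=4))  # True
```

So with `padding=d_conv - 1` the shifted behavior is still reachable (the kernel can learn a zero rightmost tap), but it is not forced.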

hrbigelow commented 10 months ago

Yes, this is a common confusion, but it has nothing to do with causal_conv1d itself.

The Mamba model, like Transformers, is an autoregressive architecture. This means that if you take $y^l_t$ to be the output of layer $l$ at timestep $t$, then $y^l_t$ depends on $y^{l-1}_{\le t}$. Note that it is $\le t$, NOT $< t$. The shift by one happens only at the very end of the entire network, where $y^L_t$ produces the logits for the $(t+1)$-th token.
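A small sketch of this dependency structure, under the simplifying assumption that every layer mixes time only through a causal window of width `k` (the SSM and elementwise parts of a real Mamba block are ignored here). `deps_after_layers` is a hypothetical helper and the token ids are made up. Stacking layers widens the receptive field backward but never past step t; the shift by one appears only when the output at step t is paired with token t + 1 as its target.

```python
def deps_after_layers(t, k, n_layers):
    """Input positions that the layer-`n_layers` output at step t can depend
    on, if every layer reads only positions t - k + 1 .. t of the layer below."""
    frontier = {t}
    for _ in range(n_layers):
        frontier = {s - j for s in frontier for j in range(k) if s - j >= 0}
    return frontier


deps = deps_after_layers(t=10, k=4, n_layers=3)
print(min(deps), max(deps))  # 1 10: wider window, but never beyond t

# Next-token training pairs (hypothetical token ids): the output at step t
# has seen tokens[: t + 1] and is scored against tokens[t + 1].
tokens = [5, 1, 7, 3, 9]
pairs = [(tokens[: t + 1], tokens[t + 1]) for t in range(len(tokens) - 1)]
print(pairs[0])  # ([5], 1)
```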

Within this setting, causal_conv1d is just one component of each layer. Calling it "causal" simply means that it does not look forward. For example, with d_conv=5 the input is left-padded with 4 positions, and we get:

pppp123456789
^^^^^
 ^^^^^
  ^^^^^
   ^^^^^
    ^^^^^
     ^^^^^

So the first output of the causal_conv1d layer is computed from the window pppp1, which includes token 1. But all of this is ultimately used to generate the logits for predicting the token at position 2.
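The window diagram can be regenerated with a few lines of Python. `conv_windows` is a hypothetical helper that left-pads with `d_conv - 1` placeholder characters and lists the window each output position sees; every window ends at its own token, which is all "causal" means here.

```python
def conv_windows(seq, d_conv, pad_char="p"):
    """Window of the left-padded sequence seen by each output position."""
    padded = pad_char * (d_conv - 1) + seq
    return [padded[t : t + d_conv] for t in range(len(seq))]


for w in conv_windows("123456789", d_conv=5)[:6]:
    print(w)
# pppp1
# ppp12
# pp123
# p1234
# 12345
# 23456
```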

sudhakarsingh27 commented 10 months ago

This helps me connect the dots, thanks! (I was trying to relate this causal_conv1d operation to the attention mechanism discussed in the "Attention Expressivity" section, D.2, of the H3 paper.)

sudhakarsingh27 commented 10 months ago

I think it does have to do with causal_conv1d as well. Attention mixes information across the sequence dimension, which is what lets it solve tasks like "associative_recall" and "induction_heads". The H3 authors try to emulate this by "shifting" tokens (with a "shift" SSM), while the Mamba authors seem to have simplified even that to just a causal convolution. Without it, wouldn't SSMs just be "fancy gated RNNs or LSTMs"?
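To make the H3 comparison concrete: assuming H3's shift SSM uses a shift matrix for A and B = e_1 (our reading of the H3 paper; please check it against the paper), its state at step t is just the last N inputs, so its output is a causal convolution of width N whose taps are the entries of C. The pure-Python sketch below, with hypothetical helper names, checks this numerically. A learned conv with `padding=d_conv - 1` covers the same family of linear maps, which is one way to read the suggestion that Mamba replaces the shift SSM with a short causal convolution; it says nothing about the input-dependent SSM that follows the conv.

```python
def shift_ssm(u, C):
    """Run x_t = A x_{t-1} + B u_t, y_t = C . x_t with A = shift matrix and
    B = e_1, so the state holds [u_t, u_{t-1}, ..., u_{t-N+1}]."""
    n = len(C)
    state = [0.0] * n
    ys = []
    for u_t in u:
        state = [u_t] + state[:-1]  # A x_{t-1} + B u_t: shift, insert u_t
        ys.append(sum(c * s for c, s in zip(C, state)))
    return ys


def causal_conv(u, w):
    """Conv1d-style cross-correlation with len(w) - 1 zeros of left padding,
    truncated to len(u) outputs."""
    k = len(w)
    up = [0.0] * (k - 1) + list(u)
    return [sum(w[j] * up[t + j] for j in range(k)) for t in range(len(u))]


u = [1.0, -2.0, 3.0, 0.5, 4.0]
C = [2.0, 1.0, -0.5]  # C[i] weights u_{t-i}
print(shift_ssm(u, C) == causal_conv(u, C[::-1]))  # True: taps are reversed
```

The reversal only reflects the cross-correlation convention: the conv's rightmost tap multiplies the current input, while `C[0]` does in the SSM form.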

PheelaV commented 6 months ago

Hi @albertfgu

Could you please confirm the above or elaborate? The Mamba paper makes exactly one reference to the "shift-SSM" from H3. I might have missed something, but apart from Figure 3 and the mention of H3's local convolution in the "SSM Architectures" subsection, there are no other comments on it. Matching similarities between the implementations and reading both papers has led me here. I am looking for the intuition behind the conv1d operation in each Mamba block. In H3, the shift-SSM is argued to help approximate the multiplicative interactions of attention, and together with the diag-SSM, which brings in attention-like memorization akin to a soft lookup, to facilitate associative search. Thanks.