The amount of information the model can store is proportional to its state size. So in this task, the reason it's able to generalize perfectly is because there isn't much information it has to remember (just one token).
Your question might be touching on a second, more nuanced point, although I'm not sure this is exactly what you mean. The vocabulary size also implicitly affects memorization ability, because the larger the vocab, the more memory is required to represent one token (information-theoretically, $\log |\text{vocab}|$ bits for a uniform distribution). So a finite state size can also only support a maximum vocabulary in principle.
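(As a concrete example, remembering a single token from the 16-token vocabulary used in this task takes only $\log_2 16 = 4$ bits, which even a very small state can hold.)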
Hi Albert,
Thank you for the explanation, and sorry, I do think I was unclear. I understand your point about informational capacity and vocab size. But it seems the problem comes a bit upstream of that, namely that Delta can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all-other-tokens. I do see that in the input-dependent version the B and C parameters are also learning something, but they are also just linear, so it doesn't seem to resolve this quandary. Also, I'm not actually asking about generalization ability - I just can't see how such a model could achieve 100% training accuracy.
From Section E.1 of the paper, my understanding is that the synthetic dataset is generated like this:
EDIT: I had forgotten to add in the special token a second time.
```python
from random import choice, choices

def random_seq(L=256, V=16, P=10):
    """
    Section E.1:
    Training consists of randomly generating data every step, with a batch size of 8.
    """
    # prefix is the region where the special token first occurs
    assert P < L - 2
    vocab = list(range(V))
    memory_tok = choice(vocab)
    other_toks = [t for t in vocab if t != memory_tok]
    # the 'special token' from Section 3.1 of https://arxiv.org/pdf/2212.14052.pdf
    special_tok = V
    seq = (choices(vocab, k=P) + [special_tok, memory_tok]
           + choices(other_toks, k=L - P - 2) + [special_tok, memory_tok])
    return seq

if __name__ == '__main__':
    L, V, P = 20, 5, 3
    print(f'seq_length={L}, vocab_size={V}, prefix={P}')
    for b in range(20):
        print(random_seq(L, V, P))
"""
seq_length=20, vocab_size=5, prefix=3
[4, 0, 3, 5, 3, 0, 1, 4, 0, 0, 1, 2, 1, 0, 2, 2, 4, 1, 2, 1, 5, 3]
[4, 3, 1, 5, 0, 4, 4, 2, 3, 4, 4, 1, 3, 4, 4, 2, 2, 2, 4, 3, 5, 0]
[0, 4, 3, 5, 1, 4, 2, 2, 3, 2, 3, 4, 4, 3, 3, 0, 0, 4, 3, 0, 5, 1]
[1, 2, 3, 5, 1, 2, 3, 2, 3, 2, 2, 3, 0, 0, 0, 4, 2, 0, 2, 2, 5, 1]
[2, 3, 2, 5, 2, 1, 4, 1, 4, 3, 1, 1, 4, 4, 3, 4, 4, 4, 0, 0, 5, 2]
[4, 1, 4, 5, 4, 0, 2, 1, 0, 3, 1, 3, 0, 1, 1, 1, 3, 2, 2, 2, 5, 4]
[3, 4, 0, 5, 3, 1, 1, 1, 0, 1, 2, 2, 0, 2, 0, 2, 4, 4, 0, 2, 5, 3]
[3, 3, 0, 5, 2, 0, 4, 3, 1, 1, 1, 1, 3, 0, 3, 4, 3, 4, 4, 1, 5, 2]
[3, 4, 4, 5, 4, 3, 0, 0, 3, 3, 1, 2, 0, 3, 3, 3, 0, 1, 2, 1, 5, 4]
[3, 1, 3, 5, 2, 1, 0, 0, 4, 1, 1, 3, 1, 3, 4, 3, 1, 1, 3, 3, 5, 2]
[1, 1, 4, 5, 2, 0, 4, 1, 0, 4, 1, 3, 4, 4, 1, 0, 1, 1, 0, 3, 5, 2]
[4, 2, 0, 5, 3, 0, 2, 0, 1, 4, 2, 0, 1, 4, 1, 4, 4, 0, 1, 0, 5, 3]
[0, 1, 1, 5, 0, 3, 4, 1, 1, 1, 3, 1, 2, 4, 1, 3, 2, 1, 4, 3, 5, 0]
[2, 2, 2, 5, 1, 0, 2, 4, 0, 4, 3, 4, 4, 0, 0, 3, 4, 4, 3, 3, 5, 1]
[0, 3, 3, 5, 2, 3, 4, 3, 1, 0, 4, 1, 1, 1, 4, 3, 1, 1, 1, 0, 5, 2]
[0, 2, 3, 5, 0, 1, 1, 1, 2, 3, 4, 1, 3, 2, 1, 3, 4, 2, 3, 1, 5, 0]
[0, 4, 3, 5, 0, 1, 1, 1, 2, 1, 1, 2, 2, 3, 1, 2, 4, 2, 4, 4, 5, 0]
[3, 2, 4, 5, 4, 1, 0, 0, 1, 1, 1, 1, 0, 0, 3, 1, 2, 3, 3, 1, 5, 4]
[0, 3, 3, 5, 4, 1, 0, 2, 3, 2, 0, 0, 1, 3, 1, 1, 3, 3, 1, 3, 5, 4]
[3, 0, 3, 5, 4, 2, 3, 3, 3, 1, 3, 3, 2, 1, 0, 2, 1, 1, 1, 0, 5, 4]
"""
So, in the training data, every token in the 16-token vocab will eventually be used as the memory_tok. My simplistic view at the moment is that the Delta operator must learn to evaluate to a positive value for memory_tok, so that $\bar{B}$ evaluates positively, and at the same time must evaluate to near zero for all other tokens != memory_tok. But Delta is basically just a perceptron, so it doesn't seem possible for it to simultaneously learn this round-robin separation of memory tokens from other tokens for every possible choice of memory_tok.
For reference, I've tried to summarize the real-valued version of the algorithm in einsum as:
Source colab here
Please do let me know if I made a mistake!
It's a 2-layer model, not 1 layer. I think you might be right that a single (S6) layer can't learn this task. Although a single Mamba block probably can, because of the local convolution before the main SSM.
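(Intuitively, the short causal convolution mixes each position with the few positions just before it, so the representation at the memory token's position already contains the preceding special token S; that lets the block's selection mechanism key on "the token immediately following S" even without a second SSM layer. This is just the architectural intuition, not something verified on trained weights.)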
Thanks for your response. I was aware it was a two-layer model actually - sorry I should have mentioned that.
To be clear, this is the graphical model structure if I'm not mistaken. The hidden state at layer $l$ and time $t$, $h^l_t$, has a Markov blanket of $h^l_{t-1}$, $h^{l-1}_t$, and $x_t$ (due to residual connections).
```
layer2  * -> * -> * -> * -> * -> * -> ... * -> *
        ^    ^    ^    ^    ^    ^        ^    ^
        |    |    |    |    |    |        |    |
layer1  * -> * -> * -> * -> * -> * -> ... * -> *
        ^    ^    ^    ^    ^    ^        ^    ^
        |    |    |    |    |    |        |    |
input   .    .    S    M    .    .    ... S    M
pos     0    1    2    3    4    5        T-1  T
```
By the data processing inequality, $I(h_t; h_{t-1}) \ge I(h_t; h_{t-2}) \ge I(h_t; h_{t-3}) \ge \cdots$ (for either layer). So the hidden states must retain the 'M' information from time 3 across that vast stretch, and the only way they can do that is to ignore the majority of the intervening information coming in. It seems the only way to ignore it is through the $\Delta(x_t)$ functions acting on $B$. But the $\Delta$ functions are just linear separators of the input, so I don't see how they can effectively do that separation.
I'm sure I'm missing something basic though - obviously the model does solve the task.
After the first layer, the representations have all been mixed. Let me use your diagram:
```
layer2  * -> * -> * -> * -> * -> * -> ... * -> *
        ^    ^    ^    ^    ^    ^        ^    ^
        |    |    |    |    |    |        |    |
layer1  * -> * -> * -> o -> * -> o -> ... * -> *
        ^    ^    ^    ^    ^    ^        ^    ^
        |    |    |    |    |    |        |    |
input   .    .    S    M1   .    M2   ... S    M
pos     0    1    2    3    4    5        T-1  T
```
If I'm understanding correctly, your objection is that:

- M1 at the bottom is important while M2 should be ignored
- but M1 and M2 have the same representation, so how can the model achieve this?

While this is true, the point is that the two `o` marks (which are outputs of the first layer) can depend on everything before them (because they are outputs of one SSM), and have different representations.
Or perhaps I'm misunderstanding your question still. Reading it again, I don't quite understand this phrase that you've repeated a few times:
> Delta can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all-other-tokens
I think maybe another point is that the model isn't classifying M (memory-token) vs all-other-tokens. It's classifying S. All it needs to do is know that it's seen S, and from then on ignore everything else.
Ahh yes this does help clarify.
So at a high level, what I was trying to understand is whether Mamba's ability to perfectly solve the induction head task across such a long context depends crucially on the fact that $\Delta(u_l)$ is input-dependent. My simplistic picture was that maybe $\Delta(u_l)$ learns to output near-zero values for some tokens, so that in the recurrence relation

$h_l = \exp(\Delta(u_l) A)\, h_{l-1} + \Delta(u_l) B(u_l)\, u_l$

when $l$ is in the intervening positions after the first 'M' (the positions we want the model to 'ignore'), the model achieves this ignoring through the second term $\Delta(u_l) B(u_l) u_l$ being close to zero for most of those timesteps.
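(In the limit, if $\Delta(u_l) = 0$ then $\exp(\Delta(u_l) A) = I$ and the input term vanishes, so $h_l = h_{l-1}$ and the state is copied through unchanged.)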
I had thought maybe this ability for Mamba to ignore long stretches was due to this idea of input-dependent gating, described here:
But that's probably not the case, right? Maybe what's really going on is that those terms aren't near zero during the recurrence over those million or so intervening tokens. Rather, the state $h_l$ just has enough subspaces that the folding in of so much more information still does not disturb the information from the first M.
Thanks again!
By the way, I do want to be mindful not to clutter github issues with theoretical discussions if that is not appropriate. But I am very grateful for your answers - this is a truly exciting development.
I think I still don't understand your question. Letting $\Delta \to 0$ where it wants to (in particular once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta=0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.
It would be cool for someone to verify by actually looking at the learned activations.
> I think I still don't understand your question. Letting $\Delta \rightarrow 0$ where it wants to (in particular once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta = 0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.
I'm starting to read your thesis by the way - there is too much background I'm unfamiliar with at the moment.
That would be very interesting if true! I still don't see it though, because in the very first layer, where the $u_l$ are freshly embedded tokens before any mixing, I don't see how $\Delta(u_l)$ could be close to zero for all of the million or so intervening token inputs - $\Delta(u_l)$ is just defined as $\tau_\Delta(\mathrm{Parameter} + \mathrm{Linear}_N(u_l))$. So, in your example above, at the very least, M2 will produce the same $\Delta$ value as M1.
In the second layer, the $u_l$ are now richer representations that include history, so it's harder to reason about them. But $\Delta(u_l)$ in the second layer is still just a linear separator, so again it's hard to imagine a pattern where $\Delta(u_3) > 0$ and $\Delta(u_i) \approx 0$ for all $i \in [4, 10^6]$ (position 3 being the occurrence of M1).
> It would be cool for someone to verify by actually looking at the learned activations.
Good idea, if I get a chance I'll try to test that.
Yes, I'm referring to the 2nd layer as previously discussed. In the second layer, you're not working with $u_t$, but $y_t$, the outputs of the first layer. And again, what the model needs to actually operate on is not the memorization tokens but the "induction token" S. It's easy for the first layer to construct representations $y_t$ that encode whether or not S has been previously seen.
I think it would be a great exercise to write down a closed-form mechanism that solves this task, which I believe is fairly simple, and empirically check if a trained model learns a similar mechanism.
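For concreteness, here is one minimal hand-constructed sketch of such a mechanism in plain Python/NumPy, written against the `random_seq` format above. The function name `induction_readout`, the hard 0/1 gate, and the one-hot state layout are just illustrative choices, not anything taken from the paper or from a trained model:

```python
import numpy as np

def induction_readout(seq, V):
    """Hand-built two-step mechanism for the induction-head task.

    seq: token ids in 0..V-1 plus the special token S = V, truncated just
         before the final answer token (i.e. it ends with the second S).
    Returns the predicted next token.
    """
    S = V

    # "Layer 1": flag positions whose previous token was S. A short local
    # convolution (or a one-step recurrence) is enough to provide this feature.
    prev_is_S = [False] + [tok == S for tok in seq[:-1]]

    # "Layer 2": a selective recurrence. The gate delta_t opens (= 1) only at
    # the position immediately after an S; everywhere else delta_t = 0, so the
    # state is copied through unchanged (cf. exp(Delta*A) -> I and the input
    # term vanishing as Delta -> 0).
    h = np.zeros(V)                        # one "memory slot" per vocab token
    for tok, flag in zip(seq, prev_is_S):
        delta = 1.0 if (flag and tok != S) else 0.0
        x = np.zeros(V)
        if tok != S:
            x[tok] = 1.0
        h = (1.0 - delta) * h + delta * x  # write only when the gate is open

    return int(np.argmax(h))               # read out the stored memory token


if __name__ == '__main__':
    # First sample sequence from the output above (V=5, S=5), minus the answer.
    seq = [4, 0, 3, 5, 3, 0, 1, 4, 0, 0, 1, 2, 1, 0, 2, 2, 4, 1, 2, 1, 5]
    assert induction_readout(seq, V=5) == 3
```

Whether the $\Delta$ activations of a trained model actually behave like this gate is exactly what the activation-inspection suggested above could check.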
Hi,
Looking at table 2 in the paper, I was astonished to see the result.
As I understand it, the synthetic data task produces sequences of the form

[S] [A] [B...] [S] [A]

where:
- [S] is a special symbol
- [A] is the token to remember
- [B...] are up to 1M arbitrary tokens

and the task is to predict [A] after the second occurrence of [S].
Given the finite size of the hidden state, this implies $\Delta(x_i)$ should be $\approx 0$ for almost every token in [B...]; otherwise the hidden state would "overflow" with information and drown out the original [A].
But the $\Delta$ function is parameterized by just a single $N \times 1$ vector (plus a possible bias). So it seems to me that, because the $\Delta(\cdot)$ function is content-specific, it can't really be used to truly facilitate unlimited in-context learning for any choice of [A] and intervening [B...]. Is that a correct intuition?
Thanks in advance,
Henry