state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Visualization of Delta (post-Softplus) values during the Induction Task #62

Open hrbigelow opened 11 months ago

hrbigelow commented 11 months ago

The following shows the actual values of the input-dependent $\Delta$ during inference of the 2-layer network on the induction-heads task, as described in the paper. I trained the model to zero loss on synthetic data of length 256 (about 40k steps with batch size 8). As in the paper, it achieves zero loss on validation data with lengths from $2^6$ to $2^{20}$. For convenience, I only show the values for the length-256 synthetic data.
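To make the task concrete, here is a minimal sketch of how one synthetic example could be generated at any length, assuming the layout described in this post: a reserved induction token early on, the memory token right after it, random filler, then the induction token again so that the memory token has to be recalled at the final position. The function name, the use of `vocab_size - 1` as the induction token ID, and the default `trigger_pos=10` are my assumptions, not taken from the paper or from mamba-recall.

```python
import numpy as np

def make_induction_example(seq_len=256, vocab_size=16, trigger_pos=10, rng=None):
    """Build one hypothetical induction-head sequence of length seq_len.

    Assumed layout: induction token at trigger_pos, memory token at
    trigger_pos + 1, random filler elsewhere, induction token again at
    seq_len - 2, and the memory token repeated at seq_len - 1.
    """
    if not 0 <= trigger_pos < seq_len - 3:
        raise ValueError("trigger_pos must leave room before the final induction token")
    rng = np.random.default_rng() if rng is None else rng
    trigger = vocab_size - 1  # reserve the last token ID as the induction token
    # Filler and memory tokens are drawn from [0, vocab_size - 1),
    # so the induction token never appears by accident.
    tokens = rng.integers(0, vocab_size - 1, size=seq_len)
    memory = tokens[trigger_pos + 1]
    tokens[trigger_pos] = trigger
    tokens[seq_len - 2] = trigger
    tokens[seq_len - 1] = memory
    return tokens

example = make_induction_example(seq_len=256, rng=np.random.default_rng(0))
```

For a last-token classification objective, one possible split is inputs `tokens[:-1]` and target `tokens[-1]`; changing `seq_len` gives longer or shorter validation sequences.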

Interactive heatmap

Here is an interactive heatmap - you can zoom in or use the tooltip.

Preview: [image: heatmap of the post-Softplus $\Delta$ values]

The x-axis is the timestep; each heatmap rectangle is centered on the timestep it represents. The y-axis is organized into 3 large chunks, each representing a different data example. Within each chunk there are 32 sub-chunks, one for each of the d_inner = 2 * d_model channels. Within each sub-chunk is a pair of rows: the bottom row shows the layer-1 $\mathrm{softplus}(\Delta)$ values and the top row shows layer 2. All three examples have the induction token at $t=10$ and the memory token at $t=11$, followed by an intervening stretch $t \in [12, 253]$ and then a second occurrence of the induction token at $t=254$ and the memory token at $t=255$.
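The row ordering described above can be reproduced with a transpose and reshape. This is a hypothetical helper, assuming the recorded $\Delta$ values are stored as an array of shape `(n_examples, n_layers, d_inner, seq_len)`; it is not the code used to make the plot.

```python
import numpy as np

def heatmap_rows(delta):
    """Stack Delta values into 2-D heatmap rows (hypothetical helper).

    delta: array of shape (n_examples, n_layers, d_inner, seq_len).
    Returns an array of shape (n_examples * d_inner * n_layers, seq_len),
    ordered example-major, then channel, then layer.
    """
    n_ex, n_layers, d_inner, seq_len = delta.shape
    # Move the layer axis inside the channel axis: (example, channel, layer, time).
    stacked = np.transpose(delta, (0, 2, 1, 3))
    return stacked.reshape(n_ex * d_inner * n_layers, seq_len)
```

Plotting the result with matplotlib's `imshow(rows, aspect="auto", origin="lower")` puts row 0 at the bottom, so layer 1 lands directly below layer 2 within each channel pair.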

The heatmap shows the $\Delta$ values at the moment they go into this call, except that, for plotting purposes, I applied Softplus to them myself: Softplus is applied inside the selective_scan_cuda.fwd function, and it was easier to grab the values during inference at that point.
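For anyone reproducing this: as far as I can tell from Mamba's reference module (`mamba_simple.py`), the `dt_proj` bias is not folded into $\Delta$ before the scan. It is passed separately as `delta_bias`, and with `delta_softplus=True` the scan computes $\mathrm{softplus}(\Delta + \text{bias})$. The sketch below mirrors that step in NumPy; the function name and array layout are assumptions, and whether the plotted values include the bias is not stated in this post.

```python
import numpy as np

def post_softplus_delta(dt, delta_bias=None):
    """Softplus-transformed Delta, mimicking the step applied inside the scan.

    dt: (batch, d_inner, seq_len) pre-activation Delta values.
    delta_bias: optional (d_inner,) per-channel bias added before Softplus.
    """
    if delta_bias is not None:
        dt = dt + delta_bias[None, :, None]  # broadcast the bias over batch and time
    # Numerically stable Softplus: log(1 + exp(x)) == logaddexp(0, x).
    return np.logaddexp(0.0, dt)
```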

Any interpretations are welcome. I'm still scratching my head on this one, but in general I think the second layer must ingest the memory token at $t=11$ in at least one channel, and must somehow learn to ignore the intervening tokens so as to preserve the memory well enough to recall it accurately at the end.
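One way to see why large and small $\Delta$ would produce this "ingest, then ignore" behaviour is a toy version of the selective recurrence. As I understand Mamba's reference scan, each channel updates its state as $h_t = \exp(\Delta_t A)\,h_{t-1} + \Delta_t B_t x_t$. The sketch below collapses this to a single channel with a scalar state and constant `A` and `B` (simplifying assumptions): a large $\Delta$ at $t=11$ erases the old state and writes the memory input, while tiny $\Delta$ afterwards keeps the state nearly unchanged and scales the intervening inputs down.

```python
import numpy as np

def scalar_selective_scan(x, delta, A=-1.0, B=1.0):
    """Toy single-channel, scalar-state selective SSM.

    h_t = exp(delta_t * A) * h_{t-1} + delta_t * B * x_t
    (A < 0, so a larger delta_t forgets the old state faster.)
    """
    h = 0.0
    states = []
    for x_t, d_t in zip(x, delta):
        h = np.exp(d_t * A) * h + d_t * B * x_t
        states.append(h)
    return np.array(states)

def final_state(delta, memory_value, seed=0):
    """Run the toy SSM on random inputs with a chosen value at t=11."""
    x = np.random.default_rng(seed).uniform(-1.0, 1.0, size=len(delta))
    x[11] = memory_value  # the memory token's input at t=11
    return scalar_selective_scan(x, delta)[-1]

seq_len = 256
selective = np.full(seq_len, 1e-3)  # tiny Delta: keep the state, damp the inputs
selective[11] = 10.0                # large Delta at t=11: erase the state, write x_11
uniform = np.ones(seq_len)          # input-independent Delta for comparison

# How much the final state depends on the memory input x_11:
print(final_state(selective, 1.0) - final_state(selective, -1.0))  # clearly nonzero: x_11 is retained
print(final_state(uniform, 1.0) - final_state(uniform, -1.0))      # essentially 0: x_11 is forgotten
```

With a uniform $\Delta = 1$, the contribution of $x_{11}$ shrinks by a factor of $e^{-1}$ per step and is gone long before the end of the sequence.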

Karami-m commented 6 months ago

@hrbigelow Can you please share more details about the model and experiment configuration you used? Based on the details in the paper, I used the following config:

dataset:
  vocab_size: 16
  input_seq_len: 256
  batch_size: 8

optimizer:
  lr: 1e-3
  weight_decay: 0.1

model:
  config:
    d_model: 64
    n_layer: 2
    ssm_cfg:
      d_state: 16
      d_conv: 4
      expand: 2
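For reference, here is a sketch of the per-layer tensor shapes this config implies for one Mamba block, assuming the parameter layout of the `Mamba` module in `mamba_ssm/modules/mamba_simple.py` with `dt_rank="auto"`, i.e. $\lceil d_\text{model}/16 \rceil$. `MambaDims` is a hypothetical helper, not part of the library.

```python
import math
from dataclasses import dataclass

@dataclass
class MambaDims:
    d_model: int = 64
    d_state: int = 16
    d_conv: int = 4
    expand: int = 2

    @property
    def d_inner(self):
        return self.expand * self.d_model

    @property
    def dt_rank(self):
        # "auto" dt_rank: ceil(d_model / 16)
        return math.ceil(self.d_model / 16)

    def shapes(self):
        """Main per-layer tensor shapes (nn.Linear weights are (out, in))."""
        return {
            "in_proj.weight": (2 * self.d_inner, self.d_model),  # x branch and gate z
            "conv1d.weight": (self.d_inner, 1, self.d_conv),     # depthwise causal conv
            "x_proj.weight": (self.dt_rank + 2 * self.d_state, self.d_inner),  # Delta, B, C
            "dt_proj.weight": (self.d_inner, self.dt_rank),
            "A_log": (self.d_inner, self.d_state),
            "out_proj.weight": (self.d_model, self.d_inner),
        }

print(MambaDims().shapes())
```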

I have tested weight_decay values of 0.001 and 0.1, but the model was not able to learn to copy the memory token.

For the codebase, I used Mamba as the sequence mixer in the Hyena codebase (https://github.com/HazyResearch/safari) for the induction-head task.

hrbigelow commented 6 months ago

Hi Karami,

The experiment I did is in mamba-recall (https://github.com/hrbigelow/mamba-recall). Hopefully that gets you the answers you need. It's been a while, so I don't remember the details right now, but if I think of anything, I'll say more.

Karami-m commented 6 months ago

Thanks for your response. Based on your implementation (https://github.com/hrbigelow/mamba-recall/blob/master/induction-head.py), you modelled it as a classification task, i.e. the training and test loss and accuracy are evaluated only on the last token: $L(\mathrm{pred}[..., -1, :], \mathrm{target}[..., -1])$


rather than as the next-token prediction task used for autoregressive models in the Hyena codebase (https://github.com/HazyResearch/safari) for the induction-head task (https://github.com/HazyResearch/safari/tree/main/configs/experiment/synthetics/induction_head). The classification formulation seems more reasonable for this task. It is interesting that Mamba learned to solve the task with this loss; in my experiment I did not train with last-token prediction only.

@tridao, can you please clarify how you trained Mamba for this task: (1) next-token prediction, $L(\mathrm{pred}[..., :, :], \mathrm{target}[..., :])$, or (2) last-token prediction, $L(\mathrm{pred}[..., -1, :], \mathrm{target}[..., -1])$?
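The two objectives differ only in which positions are supervised. A minimal NumPy sketch, assuming `pred` holds logits of shape `(batch, seq_len, vocab)` and `target` holds integer token IDs of shape `(batch, seq_len)`; the function names are hypothetical:

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy. logits: (..., V) floats; targets: (...) integer IDs."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()

def next_token_loss(pred, target):
    # (1) every position is supervised: pred (B, L, V), target (B, L)
    return cross_entropy(pred, target)

def last_token_loss(pred, target):
    # (2) only the final position is supervised: pred (B, V), target (B,)
    return cross_entropy(pred[:, -1, :], target[:, -1])
```

In PyTorch the same losses would typically use `torch.nn.functional.cross_entropy`, flattening the batch and sequence dimensions for objective (1).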

hrbigelow commented 6 months ago

Hi Mahdi,

That's a very good point. Yes, indeed, I trained only on the recall-token prediction. I interpreted the phrase "trained on the induction head task" to mean actually trained on the task, rather than trained on the typical objective and merely evaluated on the task. But it's hard to know whether I interpreted it correctly. My sense, though, was that the experiment was meant more to test whether the architecture is capable of such a task at all than to produce a realistic scenario.

Henry

albertfgu commented 6 months ago

Yes, I trained directly with a classification loss. It makes sense that without direct supervision the model won't be able to learn every mapping in the context: that would be similar to a phonebook lookup task, which pure SSMs are known to be unable to do because of their finite state.
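A rough way to see the "finite state" point: during recurrent inference, a Mamba layer carries a convolution state and an SSM state whose sizes, as I understand the inference cache in `mamba_simple.py`, are `d_inner x d_conv` and `d_inner x d_state` per sequence, independent of how many tokens have been seen. An attention layer's key/value cache instead grows linearly with context. The helpers below are hypothetical and only count floats; they say nothing about how much information each float actually holds.

```python
def ssm_cache_floats(n_layer, d_model, expand=2, d_state=16, d_conv=4):
    """Floats in a Mamba stack's recurrent cache for one sequence.

    Per layer: conv state (d_inner x d_conv) + SSM state (d_inner x d_state).
    There is no seq_len argument because this cache does not grow with context.
    """
    d_inner = expand * d_model
    return n_layer * d_inner * (d_conv + d_state)


def kv_cache_floats(n_layer, d_model, seq_len):
    """Floats in a Transformer's key/value cache for one sequence (K and V per token)."""
    return 2 * n_layer * d_model * seq_len


for seq_len in (256, 2**20):
    print(seq_len, ssm_cache_floats(2, 64), kv_cache_floats(2, 64, seq_len))
```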

Karami-m commented 6 months ago

Thanks for your clarification. May I ask what codebase and config you used? As detailed in https://github.com/state-spaces/mamba/issues/298, I integrated Mamba as a sequence mixer into the Hyena codebase:

  1. Induction head task:
     a. next-token prediction loss $L(\mathrm{pred}[..., :, :], \mathrm{target}[..., :])$: could not learn, as discussed.
     b. last-token prediction loss $L(\mathrm{pred}[..., -1, :], \mathrm{target}[..., -1])$: not tested yet.

  2. Associative recall task, which is very similar to the induction head task except that the model must memorize all (key, value) pairs:
     a. next-token prediction loss $L(\mathrm{pred}[..., :, :], \mathrm{target}[..., :])$: could learn and solve the task.
     b. last-token prediction loss $L(\mathrm{pred}[..., -1, :], \mathrm{target}[..., -1])$: could NOT solve the task.
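For comparison with the induction-head task, here is a sketch of one single-query associative recall example: a list of key-value pairs followed by a query key, with the paired value as the target. The split of token IDs between keys and values and the function name are assumptions; the safari configs may use a different layout.

```python
import numpy as np

def make_associative_recall(n_pairs=8, n_keys=16, n_values=16, rng=None):
    """Hypothetical single-query associative recall example.

    Sequence: k1 v1 k2 v2 ... kn vn q; target = the value paired with q.
    Keys use IDs [0, n_keys); values use IDs [n_keys, n_keys + n_values).
    """
    if n_pairs > n_keys:
        raise ValueError("need at least n_pairs distinct keys")
    rng = np.random.default_rng() if rng is None else rng
    keys = rng.choice(n_keys, size=n_pairs, replace=False)  # distinct keys
    values = rng.integers(n_keys, n_keys + n_values, size=n_pairs)
    seq = np.empty(2 * n_pairs + 1, dtype=np.int64)
    seq[0:-1:2] = keys    # even positions: keys
    seq[1:-1:2] = values  # odd positions: values
    q = rng.integers(n_pairs)
    seq[-1] = keys[q]     # query one of the stored keys
    return seq, values[q]
```

Unlike the induction-head example, every pair in the context may be queried, so the model has to keep all of them available in its state rather than a single memory token.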

EricLina commented 3 months ago

🧐Interesting question.