johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.

General Question: Why does self.in_proj have an expansion again? #20

Closed: lucasmgomez closed this issue 5 months ago

lucasmgomez commented 6 months ago

I understand that d_inner is d_model * expansion (E=2). But why is self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, ...)?

Why is the input projection expanded a second time by 2?

I can't seem to find the answer in the referenced section 3.4 of the paper.

Any clarification would be appreciated.

johnma2006 commented 6 months ago

Sure! Take a look at the Mamba block diagram in Fig 3 of the Mamba paper. The first thing the input x does is split into two branches, each of which goes through its own linear projection. in_proj is simply a way to compute both of those linear projections at the same time in one matmul. The result is split apart later: https://github.com/johnma2006/mamba-minimal/blob/master/model.py#L223 (x is the left branch, res is the right branch).
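Concretely, here is a minimal sketch of the pattern (the sizes, batch shape, and variable names are illustrative, not copied verbatim from model.py):

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not the repo's defaults)
d_model, expand = 64, 2
d_inner = expand * d_model  # E * D, the expansion from section 3.4

# One fused matmul computes the projections for BOTH branches at once,
# which is why the output width is d_inner * 2
in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

u = torch.randn(1, 10, d_model)  # (batch, seq_len, d_model)
x_and_res = in_proj(u)           # (batch, seq_len, 2 * d_inner)

# Split the fused output back into the two branches of Fig 3:
# x feeds the SSM path, res feeds the gating path
x, res = x_and_res.split([d_inner, d_inner], dim=-1)
print(x.shape, res.shape)  # both (1, 10, d_inner)
```

So the extra factor of 2 isn't a second expansion of the hidden size; it's just two width-d_inner projections packed into a single weight matrix for efficiency.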

lucasmgomez commented 6 months ago

Thanks, that makes sense!