johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0
2.62k stars, 191 forks

Does matrix A need to be diagonal? #14

Closed: houghtonweihu closed this issue 7 months ago

houghtonweihu commented 10 months ago

In your file https://github.com/johnma2006/mamba-minimal/blob/master/model.py, you have:

```python
A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
self.A_log = nn.Parameter(torch.log(A))
A = -torch.exp(self.A_log.float())  # shape (d_in, n)
```
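To make the shapes concrete, here is a small sketch (not code from the repository) showing one way to read the `(d_in, n)` tensor: each of the `d_in` rows holds the diagonal of its own `n × n` state matrix. It assumes illustrative sizes `d_inner = 4` and `d_state = 3`, and uses `Tensor.repeat` in place of `einops.repeat`. `torch.diag_embed` rebuilds the full matrices only to show that multiplying by them matches the cheaper elementwise product.

```python
import torch

# Illustrative sizes (assumed for this sketch, not taken from any real config)
d_inner, d_state = 4, 3

# Same initialization as the quoted snippet, with Tensor.repeat instead of einops.repeat:
# every row is [1, 2, ..., d_state].
A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)  # (d_inner, d_state)
A_log = torch.log(A)
A = -torch.exp(A_log)                                        # (d_inner, d_state), all negative

# Read each row as the diagonal of a (d_state x d_state) matrix.
A_full = torch.diag_embed(A)                                 # (d_inner, d_state, d_state)

# A diagonal matrix times a state vector is just an elementwise product.
h = torch.randn(d_inner, d_state)
via_matmul = torch.einsum('dij,dj->di', A_full, h)           # (d_inner, d_state)
via_elementwise = A * h                                      # (d_inner, d_state)
print(torch.allclose(via_matmul, via_elementwise))  # True
```

Storing only the diagonals keeps memory at `d_in * n` instead of `d_in * n * n` numbers.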

Does matrix A need to be diagonal?

Thanks!

houghtonweihu commented 10 months ago

I got an answer from Dr Dao: d_in is actually the batch dim.
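One way to read "d_in is actually the batch dim" is that `A` stores `d_in` independent diagonal matrices rather than a single `d_in × n` matrix. The sketch below assumes a zero-order-hold style discretization `A_bar = exp(delta * A)` and a hypothetical per-channel step size `delta`. It checks that exponentiating the diagonal entries elementwise gives the same result as a full matrix exponential per channel, which is one practical reason to keep `A` diagonal.

```python
import torch

d_inner, d_state = 4, 3
# Equivalent (up to rounding) to A = -exp(log(A_init)) from the snippet above.
A = -torch.arange(1, d_state + 1).float().repeat(d_inner, 1)  # (d_inner, d_state)
delta = torch.rand(d_inner) + 0.1                             # hypothetical step size per channel

# Diagonal case: channel d gets exp(delta[d] * A[d, :]), computed elementwise.
A_bar_diag = torch.exp(delta[:, None] * A)                    # (d_inner, d_state)

# Full-matrix case: one d_state x d_state matrix exponential per channel (batched over d_inner).
A_bar_full = torch.linalg.matrix_exp(delta[:, None, None] * torch.diag_embed(A))

print(torch.allclose(torch.diag_embed(A_bar_diag), A_bar_full, atol=1e-5))
```

The elementwise version costs `O(d_in * n)` per step, while a general (non-diagonal) `A` would need a dense matrix exponential per channel.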

Joeland4 commented 2 months ago

> I got an answer from Dr Dao: d_in is actually the batch dim.

What is the batch dim?
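A hedged sketch of what "batch dim" means here. It uses a simplified recurrence `h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t` and `y_t = h_t · C_t`, where the `delta * B` term is a simple approximation of the discretized input, and the function name `scan_all_channels` and all shapes are hypothetical. No operation in it mixes the `d_in` channels, so running all channels together gives the same output as running each channel alone and stacking the results, exactly as with a batch dimension.

```python
import torch

def scan_all_channels(u, delta, A, B, C):
    """Hypothetical sequential scan over all d_in channels at once.

    u, delta: (L, d_in); A: (d_in, n); B, C: (L, n). Returns y: (L, d_in).
    """
    L, d_in = u.shape
    h = torch.zeros(d_in, A.shape[1])                          # one n-dim state per channel
    ys = []
    for t in range(L):
        A_bar = torch.exp(delta[t][:, None] * A)               # (d_in, n), elementwise
        Bu = delta[t][:, None] * B[t][None, :] * u[t][:, None] # (d_in, n)
        h = A_bar * h + Bu                                     # diagonal A: nothing mixes across d_in or n
        ys.append(h @ C[t])                                    # (d_in,)
    return torch.stack(ys)                                     # (L, d_in)

torch.manual_seed(0)
L, d_in, n = 5, 4, 3
u, delta = torch.randn(L, d_in), torch.rand(L, d_in)
A = -torch.arange(1, n + 1).float().repeat(d_in, 1)
B, C = torch.randn(L, n), torch.randn(L, n)

y_batched = scan_all_channels(u, delta, A, B, C)

# Run each channel separately (d_in = 1) and concatenate along the channel axis.
y_looped = torch.cat(
    [scan_all_channels(u[:, d:d + 1], delta[:, d:d + 1], A[d:d + 1], B, C) for d in range(d_in)],
    dim=1,
)
print(torch.allclose(y_batched, y_looped, atol=1e-6))
```

In a full Mamba block, channels are still mixed outside the scan (for example by the input and output projections), so this independence applies only inside the SSM step itself.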