johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0

Create model.py2 #7

Open · elonmasai7 opened 9 months ago

elonmasai7 commented 9 months ago

Another view:

    import math
    import torch
    from torch import nn
    from torch.nn.utils import weight_norm

    class Mamba(nn.Module):
        def __init__(self, d_model, d_state, n_layers, d_inner, dropout=0.1):
            super().__init__()

            # We don't want to learn position embeddings, so we use a simple fixed
            # positional encoding. Note that we divide by `sqrt(d_model)`, which you'll
            # find across other Transformer implementations; it serves the same purpose
            # as with standard attention. The encoding is kept in `float32` so the code
            # can be used with mixed precision, and registered as a buffer so it follows
            # the module across devices; 64 is the maximum supported sequence length.
            self.register_buffer(
                "pos_enc",
                torch.arange(0, 64, dtype=torch.float32) / math.sqrt(d_model),
            )

            layers = []
            for _ in range(n_layers):
                layers.append(MambaLayer(d_model, d_state, d_inner, dropout=dropout))
            self.layers = nn.ModuleList(layers)

            # Final dense layer. Each layer produces `d_model // 2` features and the
            # per-layer outputs are concatenated before this projection; 50257 is the
            # GPT-2 vocabulary size.
            self.fc = nn.Linear(n_layers * (d_model // 2), 50257)

        def forward(self, x, state_init=None):
            # The input holds `l` groups of sequences of length `L` with batch size `b`:
            # `x` has shape `(l, b, L, d_model)`, and we assume the first dimension is
            # the `l` one.
            l, b, L, d = x.shape

            if state_init is None:
                # The state carries `d_model // 2` features (see the note on `d_state`
                # in `MambaLayer.__init__`).
                state_init = torch.zeros(l, b, 1, d // 2, dtype=x.dtype, device=x.device)

            # `self.pos_enc[:L, None]` has shape `(L, 1)` and broadcasts over the group,
            # batch and feature dimensions; this requires `L <= 64`.
            x = x + self.pos_enc[:L, None]
            states, outs = [], []

            for layer in self.layers:
                # Each layer maps `d_model` features down to `d_model // 2`, so every
                # layer reads the position-encoded input directly and shares the same
                # initial state.
                out, state = layer(x, state_init)
                states.append(state)
                # After the loop, `torch.cat(outs, dim=-1)` has shape
                # `(l, b, L, n_layers * d_model // 2)`.
                outs.append(out)

            return self.fc(torch.cat(outs, dim=-1)), torch.cat(states, dim=-2)

    class MambaLayer(nn.Module):
        def __init__(self, d_model, d_state, d_inner, dropout=0.1):
            super().__init__()
            d_model_half = d_model // 2

            self.lin_A = nn.Linear(d_model, d_model_half)
            self.lin_D = nn.Linear(d_model, d_model_half)

            self.lin_in = nn.Linear(d_model, d_inner)
            self.lin_B1 = nn.Linear(d_inner, d_model_half)
            # `lin_B2` is applied to the state, whose feature size is `d_model // 2`,
            # so this assumes `d_state == d_model // 2`.
            self.lin_B2 = nn.Linear(d_state, d_model_half)
            self.lin_C = weight_norm(nn.Linear(d_model_half, d_model_half))

            self.dropout = nn.Dropout(dropout)

        def forward(self, x, state_init):
            # We output both the state AND the transformed sequence (`x`).
            # `x` is expected to have shape `(l, b, L, d_model)` and `state_init` to
            # have shape `(l, b, 1, d_model // 2)`.
            l, b, L, d = x.shape

            # We found a tanh activation to work well for A and D.
            A = torch.tanh(self.lin_A(x))
            D = torch.tanh(self.lin_D(x))

            a = self.dropout(self.lin_in(x))
            b1 = self.lin_B1(a)
            b2 = self.dropout(self.lin_B2(state_init))
            B = b1 + b2
            c = self.lin_C(self.dropout(A * B))
            # `D`, `c` and the broadcast `state_init` all carry `d_model // 2` features,
            # so the updated state has shape `(l, b, L, d_model // 2)`.
            state = D * state_init + c

            # It may look as though `state_init` is off by one timestep relative to A, B,
            # C and D, but this is not the case, because we start the loop on the 2nd
            # timestep. It is perfectly consistent with the equations of Mamba (see [1],
            # Algorithm 2). Intuitively, we also need the state at time `t - 1` rather
            # than `t` to compute `x_t`: `state_{t-1}` is a consequence of `x_{t-1}` and
            # `u_{t-1}`. If we were to use `state_t` instead, this would be equivalent to
            # having `δ_t = 1` rather than `δ_t = 0`, the latter being what the authors'
            # "zero-input" assumption gives (see Equation (7) in [1]).
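            # For reference, the discretized recurrence in [1] reads
            #     h_t = Abar_t * h_{t-1} + Bbar_t * u_t,    y_t = C_t * h_t,
            # i.e. the previous state `h_{t-1}` enters the update at step `t`, which is
            # the role `state_init` plays in the state update above.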
            x = A * B + c

            # We return both the new output sequence `x` and the new state.
            return x, state
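
Here is a minimal smoke test for the version above, a sketch with arbitrary example sizes (note that `d_state` is set to `d_model // 2`, which the state update assumes):

    if __name__ == "__main__":
        model = Mamba(d_model=64, d_state=32, n_layers=2, d_inner=128)
        x = torch.randn(3, 2, 16, 64)   # (l, b, L, d_model), with L <= 64
        logits, states = model(x)
        print(logits.shape)             # torch.Size([3, 2, 16, 50257])
        print(states.shape)             # torch.Size([3, 2, 32, 32])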