state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Encoder-decoder architecture #78

Open sentialx opened 9 months ago

sentialx commented 9 months ago

What would be the preferred way to make an encoder-decoder architecture with Mamba? I tried concatenating embeddings to the decoder inputs with no luck. My use case is a diffusion model, and the encoder would be used for conditioning.
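
Roughly, the kind of concatenation I mean (an illustrative sketch, not my exact code; the layer names and sizes are placeholders, and it uses the stock Mamba block from mamba_ssm):

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class ConditionedMambaDenoiser(nn.Module):
    """Hypothetical diffusion denoiser: encoder embeddings are prepended to the
    decoder inputs so the causal scan sees the conditioning first."""
    def __init__(self, d_latent=64, d_cond=256, d_model=256, n_layers=4):
        super().__init__()
        self.cond_proj = nn.Linear(d_cond, d_model)    # map encoder embeddings into the model width
        self.in_proj = nn.Linear(d_latent, d_model)    # map noisy latents into the model width
        self.layers = nn.ModuleList([Mamba(d_model=d_model) for _ in range(n_layers)])
        self.out_proj = nn.Linear(d_model, d_latent)   # predict the denoised latents / noise

    def forward(self, x_noisy, cond):
        # x_noisy: (batch, seq_len, d_latent) noisy latents; cond: (batch, cond_len, d_cond)
        cond_tokens = self.cond_proj(cond)
        h = torch.cat([cond_tokens, self.in_proj(x_noisy)], dim=1)  # concatenate along the sequence axis
        for layer in self.layers:
            h = layer(h)
        return self.out_proj(h[:, cond_tokens.shape[1]:])           # drop the conditioning positions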

albertfgu commented 9 months ago

It's still an open question how to do this with SSMs.

ElliottDyson commented 7 months ago

What would be the preferred way to make an encoder-decoder architecture with Mamba? I tried concatenating embeddings to decoder inputs with no luck. My use case is a diffusion model and the encoder would be used for conditioning

Given the inherent nature of the model, I struggle to see how one might build an encoder-only component, which explains why it's still an open question, as @albertfgu stated. Hopefully this is solvable, though; until then, I struggle to see how this model will translate to non-autoregressive (non-continuative) or multi-modal problem domains.

stanleyshly commented 5 months ago

@ElliottDyson Do you still have the code from your attempt? What went wrong; was the model just not converging?

@albertfgu To your knowledge, has any work been done on vector to vector Mamba or Mamba derivative models?

ElliottDyson commented 5 months ago

@ElliottDyson Do you still have the code from your attempt? What went wrong; was the model just not converging?

@albertfgu To your knowledge, has any work been done on vector to vector Mamba or Mamba derivative models?

As for the code, I'm afraid not; I never got that far due to other projects. As for Mamba derivatives, have a flick through the most recent pages of papers on Hugging Face; there have been a few.

stanleyshly commented 5 months ago

Thank you for your response. I don't see any Mamba models that do vec2vec, though; do you have a link to any?

ElliottDyson commented 5 months ago

Thank you for your response. I don't see any Mamba models that do vec2vec, though; do you have a link to any?

Something along the lines of this may work (sequence length of 1, fixed regression task) as the forward method of the Mamba class:

def forward(self, hidden_states):
    batch, seqlen, dim = hidden_states.shape
    assert seqlen == 1, "For regression, the input should be a single vector"

    # We do matmul and transpose BLH -> HBL at the same time
    xz = rearrange(
        self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
        "d (b l) -> b d l",
        l=seqlen,
    )
    if self.in_proj.bias is not None:
        xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

    x, z = xz.chunk(2, dim=1)

    # Compute short convolution
    x = self.act(self.conv1d(x)[..., :seqlen])

    x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
    dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.t()
    dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
    B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
    C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

    # Recover the state matrix A from its log parameterisation (needed by selective_scan_fn)
    A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
    assert self.activation in ["silu", "swish"]
    y = selective_scan_fn(
        x,
        dt,
        A,
        B,
        C,
        self.D.float(),
        z=z,
        delta_bias=self.dt_proj.bias.float(),
        delta_softplus=True,
        return_last_state=False,
    )

    y = rearrange(y, "b d l -> b l d")
    out = self.out_proj(y)
    return out.squeeze(1)  # Remove the sequence dimension
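
For example, assuming that forward is patched onto a standard Mamba block (e.g. via a subclass), a call might look something like this (shapes are illustrative):

import torch
from mamba_ssm import Mamba

model = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).cuda()  # modified forward assumed patched in
x = torch.randn(8, 1, 64, device="cuda")  # (batch, seqlen=1, dim): one input vector per example
y = model(x)                              # (batch, dim) after the squeeze in the modified forward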

Or, if you meant sequence of vectors to sequence of vectors (the original implementation, but continuous instead of tokenised), then try this instead for the Mamba forward method:

def forward(self, hidden_states):
    batch, seqlen, dim = hidden_states.shape

    # We do matmul and transpose BLH -> HBL at the same time
    xz = rearrange(
        self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
        "d (b l) -> b d l",
        l=seqlen,
    )
    if self.in_proj.bias is not None:
        xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

    x, z = xz.chunk(2, dim=1)

    # Compute short convolution
    x = self.act(self.conv1d(x)[..., :seqlen])

    x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
    dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.t()
    dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
    B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
    C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

    # Recover the state matrix A from its log parameterisation (needed by selective_scan_fn)
    A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
    assert self.activation in ["silu", "swish"]
    y = selective_scan_fn(
        x,
        dt,
        A,
        B,
        C,
        self.D.float(),
        z=z,
        delta_bias=self.dt_proj.bias.float(),
        delta_softplus=True,
        return_last_state=False,
    )

    y = rearrange(y, "b d l -> b l d")
    out = self.out_proj(y)
    return out

Please let me know if this ends up working 🙂

Of course, you'll also need to change the input and output preprocessing to not use tokenisation, and change the loss function for training too.
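
A minimal sketch of what that continuous pipeline might look like (purely illustrative; the layer names, sizes, and MSE loss are assumptions, and it uses the stock Mamba block rather than the modified forward above for brevity):

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class ContinuousMamba(nn.Module):
    """Sequence-of-vectors to sequence-of-vectors: linear projections replace the
    token embedding and the vocabulary head."""
    def __init__(self, d_in, d_model, d_out, n_layers=2):
        super().__init__()
        self.in_proj = nn.Linear(d_in, d_model)    # replaces nn.Embedding / tokenisation
        self.layers = nn.ModuleList([Mamba(d_model=d_model) for _ in range(n_layers)])
        self.out_proj = nn.Linear(d_model, d_out)  # replaces the LM head

    def forward(self, x):  # x: (batch, seqlen, d_in)
        h = self.in_proj(x)
        for layer in self.layers:
            h = layer(h)
        return self.out_proj(h)  # (batch, seqlen, d_out)

model = ContinuousMamba(d_in=32, d_model=128, d_out=32).cuda()
x = torch.randn(4, 100, 32, device="cuda")
target = torch.randn(4, 100, 32, device="cuda")
loss = nn.functional.mse_loss(model(x), target)  # regression loss instead of cross-entropy
loss.backward()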