kyutai-labs / moshi


Is the depth transformer essentially a feed-forward network? #108

Closed TianwenWei closed 1 month ago

TianwenWei commented 1 month ago

Due diligence

Topic: The PyTorch implementation

Question

Hello,

Thank you for your extraordinary work. I have one question regarding the implementation of your depth transformer. After examining the code (below), I can see that the depth transformer only accepts input sequences of length 1 and yields the logits of the token for the next codebook. Since the input sequence length is limited to 1, there appears to be no KV cache involved, in which case the input token in the self-attention layer can only attend to itself. The softmax over a single position is trivially 1, so the output of the self-attention layer is just the Value, which is nothing more than the matmul of the input vector with the Value matrix. If so, the depth transformer would effectively be a feed-forward network applied token by token.

```python
def forward_depformer(
        self,
        depformer_cb_index: int,
        sequence: torch.Tensor,
        transformer_out: torch.Tensor,
    ) -> torch.Tensor:
        B, K, S = sequence.shape
        assert (
            K == 1
        ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
        assert (
            S == 1
        ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
        assert (
            transformer_out.shape[1] == 1
        ), "Transformer out should be a for a single step."

        last_token_input: tp.Optional[torch.Tensor] = None
        depformer_input = transformer_out
        if self.depformer_multi_linear:
            depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
        else:
            depformer_input = self.depformer_in[0](depformer_input)

        if depformer_cb_index == 0:
            last_token_input = self.depformer_text_emb(sequence[:, 0])
        else:
            last_token_input = self.depformer_emb[depformer_cb_index - 1](
                sequence[:, 0]
            )
        depformer_input = depformer_input + last_token_input
        assert depformer_input.shape[1] == 1
        # depformer_input is [B, 1, depformer_dim].
        # The streaming state of the depformer ensures that the proper layer is run.
        dep_output = self.depformer(depformer_input)
        logits = self.linears[depformer_cb_index](dep_output)
        logits = logits[:, None]
        assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
        return logits
```
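
To make the claim concrete, here is a minimal standalone sketch (not Moshi code: a single head with random weight matrices, an arbitrary dimension, and no bias or output projection) showing that with one query and one key and no cache, the softmax over the single score is exactly 1, so attention returns the value projection unchanged:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
x = torch.randn(1, 1, d)  # [B=1, S=1, d]: a single input token

# Random projection matrices, standing in for learned weights.
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ w_q, x @ w_k, x @ w_v

# With a single key, the softmax normalizes one score to exactly 1.0 ...
attn = F.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)  # [1, 1, 1]
out = attn @ v

# ... so the attention output is exactly the value projection.
assert torch.allclose(attn, torch.ones(1, 1, 1))
assert torch.allclose(out, v)
```

If earlier codebook steps were cached as keys and values, this degeneracy would disappear, which is what my question comes down to.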
TianwenWei commented 1 month ago

I now understand: the KV cache is handled automatically in `class StreamingMultiheadAttention(StreamingModule[_MHAState])`, so each single-token call still attends to the keys and values cached from the previous codebook steps.
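
For future readers, the mechanism can be sketched as follows. This is a hypothetical toy (`TinyStreamingAttention` is not Moshi's class: single head, no bias, and the cache is never reset), but it shows how per-step calls with a persistent KV cache add up to ordinary causal attention across codebook steps:

```python
import torch
import torch.nn.functional as F

class TinyStreamingAttention(torch.nn.Module):
    """Single-head attention with a grow-as-you-go KV cache (illustrative only)."""

    def __init__(self, d: int):
        super().__init__()
        self.qkv = torch.nn.Linear(d, 3 * d, bias=False)
        self.cache_k, self.cache_v = [], []  # filled one step at a time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, 1, d], one codebook step per call.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        self.cache_k.append(k)
        self.cache_v.append(v)
        keys = torch.cat(self.cache_k, dim=1)    # [B, t+1, d]
        values = torch.cat(self.cache_v, dim=1)  # [B, t+1, d]
        # Step t attends to all cached steps 0..t, not just to itself.
        attn = F.softmax(q @ keys.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return attn @ values

mha = TinyStreamingAttention(d=8)
for _ in range(3):                   # e.g. one call per codebook index
    out = mha(torch.randn(1, 1, 8))  # [1, 1, 8]
```

Because keys and values accumulate across calls, step t attends to steps 0..t, so the depth transformer behaves like a regular causal transformer over the codebook dimension rather than a feed-forward network.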