[X] I have done my due diligence in trying to find the answer myself.
Topic
The PyTorch implementation
Question
Hello,
Thank you for your extraordinary work.
I have one question regarding the implementation of your depth transformer network. After examining the code (below), I see that the depth transformer only accepts an input sequence of length 1 and yields the logits of the token for the next codebook. Since the input sequence length is limited to 1 and there is no KV cache involved, the input token in the self-attention layer can only attend to itself. As a result, the output of the self-attention layer is just the Value, which is nothing more than the matmul of the input vector with the Value matrix.
def forward_depformer(
    self,
    depformer_cb_index: int,
    sequence: torch.Tensor,
    transformer_out: torch.Tensor,
) -> torch.Tensor:
    B, K, S = sequence.shape
    assert (
        K == 1
    ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
    assert (
        S == 1
    ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
    assert (
        transformer_out.shape[1] == 1
    ), "Transformer out should be for a single step."
    last_token_input: tp.Optional[torch.Tensor] = None
    depformer_input = transformer_out
    if self.depformer_multi_linear:
        # One input projection per codebook.
        depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
    else:
        # A single shared input projection.
        depformer_input = self.depformer_in[0](depformer_input)
    if depformer_cb_index == 0:
        # The first depformer step is conditioned on the text token.
        last_token_input = self.depformer_text_emb(sequence[:, 0])
    else:
        # Later steps are conditioned on the previously sampled audio token.
        last_token_input = self.depformer_emb[depformer_cb_index - 1](
            sequence[:, 0]
        )
    depformer_input = depformer_input + last_token_input
    assert depformer_input.shape[1] == 1
    # depformer_input is [B, 1, depformer_dim].
    # The streaming state of the depformer ensures that the proper layer is run.
    dep_output = self.depformer(depformer_input)
    logits = self.linears[depformer_cb_index](dep_output)
    logits = logits[:, None]
    assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
    return logits
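To make my claim concrete, here is a minimal standalone check of the attention collapse (the tensor and layer names are illustrative, not taken from your repo): with a single token and no cached keys/values, the softmax is taken over a single score, so the attention weight is identically 1 and the output equals the value projection.

import torch

# Illustrative check, not repo code: single-token self-attention with no
# KV cache collapses to the value projection of that token.
torch.manual_seed(0)
d = 16
x = torch.randn(1, 1, d)  # [B, S, D] with S == 1, as enforced above

w_q = torch.nn.Linear(d, d, bias=False)
w_k = torch.nn.Linear(d, d, bias=False)
w_v = torch.nn.Linear(d, d, bias=False)

q, k, v = w_q(x), w_k(x), w_v(x)
attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
print(attn)  # tensor([[[1.]]]) -- the token can only attend to itself
print(torch.allclose(attn @ v, v))  # True: the output is exactly the Value

The same collapse happens per head in multi-head attention, with the output projection composed on top, so the whole layer would act as a fixed linear map on the single input token.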