The `n_qk_heads` parameter represents the number of query (C) and key (B) projection heads, while `n_v_heads` refers to the number of value (X) projection heads. This notation is aligned with the State Space Duality described in Mamba2.
In this context, the query and key head counts are tied together, which leads to a multi-value architecture rather than a multi-query one. Note that setting `n_qk_heads` (query-key heads) to `num_key_value_heads` tries to match the QK projections to the KV projections, but the two groupings are not inherently compatible.
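For concreteness, a rough sketch of the head-count mapping for a plain multi-head-attention teacher (the helper is purely illustrative; the attribute names follow the usual Hugging Face config convention):

```python
def pick_head_counts(hf_config):
    """Illustrative only: map a Transformer config to SSD-style head counts."""
    # In the SSD view, C plays the role of Q, B the role of K, and X the role of V.
    # Because B and C share a single head count (n_qk_heads), the only setting that
    # matches a vanilla multi-head attention teacher head-for-head is
    # n_qk_heads == n_v_heads == num_attention_heads.
    n_qk_heads = hf_config.num_attention_heads  # C (query) and B (key) heads
    n_v_heads = hf_config.num_attention_heads   # X (value) heads
    return n_qk_heads, n_v_heads
```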
Regarding the high loss, I would suggest the following: make sure you're using our latest matrix mixer implementation and that everything is properly aligned. I'd also suggest starting with distilling one Transformer into another as a quick sanity check; this kind of distillation is usually fast and straightforward.
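For that sanity check, here is a minimal sketch of what a stage-1 (matrix-mixer alignment) loss can look like; the assumption that `materialize_mixer` yields a `(batch, n_heads, length, length)` matrix matching the teacher's attention matrix is mine, not the repo's exact interface:

```python
import torch

def stage1_alignment_loss(student_mixer_matrix: torch.Tensor,
                          teacher_attn_matrix: torch.Tensor) -> torch.Tensor:
    """Sketch: align the student's materialized mixing matrix with the
    teacher's attention matrix via a Frobenius-norm distance.

    Both inputs are assumed to be shaped (batch, n_heads, length, length).
    """
    diff = student_mixer_matrix - teacher_attn_matrix
    return torch.linalg.matrix_norm(diff, ord="fro").mean()
```

When the student is another Transformer with the same head layout, this per-layer loss should fall quickly; if it doesn't, the alignment (causal masking, head ordering, normalization) is the usual culprit.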
@ostix360 Also, I would make sure to freeze the rest of the student weights.
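Something along these lines does it (the `model.layers[*].mixer` path is a placeholder for wherever the distilled mixer actually lives in the student):

```python
def freeze_all_but_mixers(student_model):
    # Freeze every parameter first...
    for param in student_model.parameters():
        param.requires_grad = False
    # ...then re-enable gradients only for the mixer modules being distilled.
    for layer in student_model.model.layers:      # placeholder module path
        for param in layer.mixer.parameters():    # placeholder attribute name
            param.requires_grad = True
```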
Thanks for your quick answer. Indeed, I forgot to update `materialize_mixer`, and that's what caused the error. And yes, I took care of freezing the rest of the weights. So only pure multi-head attention will work with your implementation; otherwise, I would have to modify the Mamba architecture to make it compatible with, or let's say equivalent to, grouped-query attention Transformers?
yep
Hi, thanks for your work. I want to create a general script that can transform any Transformer model into a hybrid version. I'm struggling to get a correct loss value for stage 1, and I don't think the model learns... I almost copy-pasted your code (the training code and the discrete Mamba class). Here is an example of the loss I got:
Here is the training loop:
The only thing I'm not so sure about is what you meant by `n_qk_heads` and `n_v_heads`. I'm trying to convert a small Qwen model (the 2.0 and now the 2.5, at 1.5B), and in the config file there is:

If I set `n_qk_heads = num_key_value_heads` and `n_v_heads = num_attention_heads`, I get an assertion error: `assert A_log.shape == (batch_size, length, n_heads)`. So I decided to set both to `num_attention_heads`.
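To make that mismatch concrete, here is a tiny illustration with made-up numbers; the guess about which side the mixer derives its single `n_heads` from is an assumption on my part, not taken from the implementation:

```python
import torch

# Purely illustrative numbers (not the real Qwen values): a GQA-style teacher.
batch_size, length = 2, 16
num_attention_heads = 12   # would drive n_v_heads (X projections)
num_key_value_heads = 2    # the first attempt for n_qk_heads (B/C projections)

# If the mixer derives its single n_heads from the value side while A_log is built
# from the QK side (or vice versa), the shapes disagree and the check fails:
n_heads = num_attention_heads
A_log = torch.zeros(batch_size, length, num_key_value_heads)
assert A_log.shape == (batch_size, length, n_heads)  # AssertionError: 2 heads != 12
```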
Thanks in advance!