HanGuo97 opened 5 days ago
Yes, that seems correct. Good spotting! I have a fairly extended ramble on this below (ignore it if it's too much :)). cc @molbap
We can also see it from the einsum notation in the `ssd_minimal` script:
`bhzc,bchpn->bzhpn`
Since `torch_forward` does not use einsum, I'll write out the shapes right before the sum using the same dimension letters, with the broadcast dimensions (added via `None` indexing) written as 1:
decay chunk: `bhzc11`
permuted states: `bh1cpn`
Multiplying them gives `bhzcpn` before the sum. Since we want the shape `bzhpn`, we need to sum over `c` (i.e. `dim=3`) and permute afterwards.
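Here is a quick numerical check of this shape argument. It is a sketch with arbitrary sizes and random inputs. `z == c` mirrors the square `decay_chunk` in `torch_forward`. The permute/broadcast pattern follows the description above and is not copied from the library code.

```python
import torch

torch.manual_seed(0)
# Arbitrary illustrative sizes: batch, heads, target chunks, source chunks, head dim, state dim
b, h, z, c, p, n = 2, 3, 4, 4, 5, 6
decay_chunk = torch.rand(b, h, z, c)  # bhzc
states = torch.randn(b, c, h, p, n)   # bchpn

# Reference: the einsum from the ssd_minimal script
reference = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Broadcast product bhzc11 * bh1cpn -> bhzcpn
states_permuted = states.permute(0, 2, 1, 3, 4)  # bchpn -> bhcpn
product = decay_chunk[..., None, None] * states_permuted[:, :, None, ...]

# Sum over c (dim=3), then permute bhzpn -> bzhpn
fixed = product.sum(dim=3).permute(0, 2, 1, 3, 4)
# Reduction over z (dim=2): same shape here because z == c, but different values
current = product.sum(dim=2).permute(0, 2, 1, 3, 4)

print(torch.allclose(fixed, reference, atol=1e-5))    # True
print(torch.allclose(current, reference, atol=1e-5))  # False
```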
Just a quick idea (not checked): we might not need to permute twice. Permuting only the decay chunks could be enough:
states: `bc1hpn`
permuted decay chunks: `bczh11`
This results in `bczhpn`, and after summing over `dim=1`, in `bzhpn`. That way we avoid the double permutation and permute only "once".
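The single-permutation idea can be checked against the `ssd_minimal` einsum in a few lines. This is a sketch with arbitrary sizes. `z` and `c` are deliberately different here, although they are equal in `torch_forward`, so that a reduction over the wrong axis would show up as a shape mismatch.

```python
import torch

torch.manual_seed(0)
# Arbitrary illustrative sizes; z != c on purpose to expose wrong-axis reductions
b, h, z, c, p, n = 2, 3, 5, 4, 7, 6
decay_chunk = torch.rand(b, h, z, c)  # bhzc
states = torch.randn(b, c, h, p, n)   # bchpn, left unpermuted

reference = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Swap h and c once: bhzc -> bczh
decay_permuted = decay_chunk.transpose(1, 3)
# bczh11 * bc1hpn -> bczhpn, then sum over c (dim=1) -> bzhpn
new_states = (decay_permuted[..., None, None] * states[:, :, None, ...]).sum(dim=1)

print(new_states.shape)                                  # torch.Size([2, 5, 3, 7, 6])
print(torch.allclose(new_states, reference, atol=1e-5))  # True
```

`transpose` returns a non-contiguous view, but that does not matter for the broadcast multiply. The states tensor is never permuted.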
I'm a bit surprised that the operations that follow don't fail. Have you tested your fixed version with a forward pass?
As far as I remember, the following operations don't fail because the reduction was on the number of (source) chunks even though it should be on the number of (target) chunks. During training, these two have the same size.
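This sketch shows why nothing downstream breaks. `reduce_over_dim2` is a hypothetical helper that mimics the `dim=2` reduction. The chunk axis of its output always has the size of the source chunks `c`. With the square decay in `torch_forward`, that coincides with the target size `z`. The non-square case is only a hypothetical input used to make the difference visible.

```python
import torch

def reduce_over_dim2(decay_chunk: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking the dim=2 reduction (sums over target chunks z)."""
    states_permuted = states.permute(0, 2, 1, 3, 4)                            # bchpn -> bhcpn
    product = decay_chunk[..., None, None] * states_permuted[:, :, None, ...]  # bhzcpn
    return product.sum(dim=2).permute(0, 2, 1, 3, 4)                           # bhcpn -> bchpn

b, h, p, n = 2, 3, 7, 6
# Square decay (z == c == 4): the output shape happens to be the expected bzhpn
square = reduce_over_dim2(torch.rand(b, h, 4, 4), torch.randn(b, 4, h, p, n))
print(square.shape)  # torch.Size([2, 4, 3, 7, 6])

# Hypothetical non-square decay (z=5 targets, c=4 sources): chunk axis has size c, not z
rect = reduce_over_dim2(torch.rand(b, h, 5, 4), torch.randn(b, 4, h, p, n))
print(rect.shape)    # torch.Size([2, 4, 3, 7, 6]) rather than [2, 5, 3, 7, 6]
```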
It's been a while, but yeah, that makes sense. Thx for clarifying!
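A tad late, but I've now verified it myself, based on my test and a correspondingly modified local copy of `ssd_minimal`. Both `dim=3` and the single-permutation variant work:
```python
# Permute only the decay chunks: (b, h, z, c) -> (b, c, z, h)
decay_chunk = decay_chunk.transpose(1, 3)
# (b, c, z, h, 1, 1) * (b, c, 1, h, p, n), summed over c -> (b, z, h, p, n)
new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
```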
Will have a look ASAP, thanks @vasqu!
System Info
NA
Who can help?
@ArthurZucker
Reproduction
NA
Expected behavior
In the `torch_forward` part of Mamba2, it seems like the reduction dimension should be `dim=3` instead of `dim=2`? https://github.com/huggingface/transformers/blob/30335093276212ce74938bdfd85bfd5df31a668a/src/transformers/models/mamba2/modeling_mamba2.py#L560
With `dim=3`, the output seems to more or less match that of Mamba-2's `ssd_minimal` implementation, but not with `dim=2`.
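To see why only `dim=3` can match `ssd_minimal`, here is a sketch with a structured decay instead of random values. It assumes `decay_chunk[z, c] = exp(sum of per-chunk log-decays over chunks c+1..z)` for `c <= z` and 0 otherwise, which is the lower-triangular segment-sum pattern. `segsum_decay` is a hypothetical stand-in written as a naive cumsum difference. It is not the `ssd_minimal` function and is less numerically stable than a masked-cumsum version. The key observation is that summing over `z` (`dim=2`) gives `(sum_z decay[z, c]) * states[c]`. Each chunk's state is only rescaled, and nothing is passed between chunks.

```python
import torch

def segsum_decay(log_a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[..., i, j] = exp(log_a[..., j+1:i+1].sum()) for j <= i, else 0."""
    T = log_a.size(-1)
    cs = torch.cumsum(log_a, dim=-1)
    diff = cs[..., :, None] - cs[..., None, :]  # cs[i] - cs[j]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    return torch.exp(diff.masked_fill(~mask, float("-inf")))  # exp(-inf) = 0 above the diagonal

torch.manual_seed(0)
b, h, T, p, n = 2, 3, 4, 5, 6
log_a = -torch.rand(b, h, T)        # negative per-chunk log-decays (illustrative)
decay_chunk = segsum_decay(log_a)   # bhzc, lower triangular
states = torch.randn(b, T, h, p, n) # bchpn

product = decay_chunk[..., None, None] * states.permute(0, 2, 1, 3, 4)[:, :, None, ...]  # bhzcpn
dim3 = product.sum(dim=3).permute(0, 2, 1, 3, 4)  # bzhpn
dim2 = product.sum(dim=2).permute(0, 2, 1, 3, 4)  # bchpn

# dim=2: each chunk's state is scaled by its own column sum; no cross-chunk mixing
own_scale = decay_chunk.sum(dim=2).permute(0, 2, 1)[..., None, None]  # (b, c, h, 1, 1)
print(torch.allclose(dim2, own_scale * states, atol=1e-5))  # True
# dim=3: target chunk z accumulates source chunks c <= z, matching the einsum
print(torch.allclose(dim3, torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states), atol=1e-5))  # True
# The first target chunk only sees the first source chunk
print(torch.allclose(dim3[:, 0], states[:, 0]))  # True
```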