huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mamba2 `torch_forward` reduction dimension possibly incorrect? #34817

Open HanGuo97 opened 5 days ago

HanGuo97 commented 5 days ago

System Info

NA

Who can help?

@ArthurZucker

Reproduction

NA

Expected behavior

In the `torch_forward` part of Mamba2, it seems like the reduction dimension should be `dim=3` instead of `dim=2`?

https://github.com/huggingface/transformers/blob/30335093276212ce74938bdfd85bfd5df31a668a/src/transformers/models/mamba2/modeling_mamba2.py#L560

With `dim=3`, the output more or less matches that of Mamba-2's `ssd_minimal` implementation; with `dim=2`, it does not.

vasqu commented 5 days ago

Yes, that seems correct. Good spotting! I have a fairly extended ramble on this below (ignore it if it's too much :)). cc @molbap

We can also see it from the einsum notation in the ssd minimal script: `bhzc,bchpn->bzhpn`

Since the code here does not use einsum, I'll write the shapes with the same dimension letters before the sum, and mark axes that are broadcast (via `None`) as a simple `1`:

decay chunk: `bhzc11`
permuted states: `bh1cpn`

The multiplication before the sum therefore gives `bhzcpn`, and since we want shape `bzhpn`, we need to sum along `c` (on `dim=3`) and reshape afterwards.
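The shape argument above can be checked numerically. The following is a minimal sketch, assuming NumPy arrays in place of the PyTorch tensors and made-up sizes for `b`, `h`, `c`, `p`, `n`; the variable names are illustrative and not copied from `modeling_mamba2.py`. It compares the broadcast-multiply-then-sum against the einsum `bhzc,bchpn->bzhpn` for both candidate reduction axes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes; z (first chunk axis) equals c (second chunk axis), as discussed in the thread.
b, h, c, p, n = 2, 3, 4, 5, 6
z = c

decay_chunk = rng.standard_normal((b, h, z, c))   # layout: b h z c
states = rng.standard_normal((b, c, h, p, n))     # layout: b c h p n

# Reference result using the einsum from the ssd minimal script.
reference = np.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Broadcast version: bhzc11 * bh1cpn -> bhzcpn
states_permuted = states.transpose(0, 2, 1, 3, 4)  # b h c p n
product = decay_chunk[:, :, :, :, None, None] * states_permuted[:, :, None, ...]

# Reduce over c (axis 3) versus z (axis 2), then move to b z h p n.
sum_dim3 = product.sum(axis=3).transpose(0, 2, 1, 3, 4)
sum_dim2 = product.sum(axis=2).transpose(0, 2, 1, 3, 4)

matches_dim3 = np.allclose(sum_dim3, reference)
matches_dim2 = np.allclose(sum_dim2, reference)
print("dim=3 matches einsum:", matches_dim3)
print("dim=2 matches einsum:", matches_dim2)
```

With random inputs, only the `dim=3` reduction agrees with the einsum, even though both results have the same shape when `z == c`.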

Just a quick idea: I'm not sure we even have to reshape twice. We might get away with reshaping only the decay chunks (not checked):

states: `bc1hpn`
permuted decay chunks: `bczh11`

This results in `bczhpn` and finally `bzhpn` (after a sum on `dim=1`), so we avoid the double permutation and just do it "once".
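The single-permutation idea can be sketched the same way. This is an illustrative check only, again assuming NumPy in place of PyTorch and hypothetical sizes; it is not the code that exists in the repository.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes, with z == c as in the thread.
b, h, c, p, n = 2, 3, 4, 5, 6
z = c

decay_chunk = rng.standard_normal((b, h, z, c))   # layout: b h z c
states = rng.standard_normal((b, c, h, p, n))     # layout: b c h p n

reference = np.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Permute only the decay chunks: b h z c -> b c z h
decay_permuted = decay_chunk.transpose(0, 3, 2, 1)

# bczh11 * bc1hpn -> bczhpn
product = decay_permuted[..., None, None] * states[:, :, None, ...]

# Summing over c (axis 1) lands directly on b z h p n, with no second permute.
single_permute = product.sum(axis=1)

single_permute_ok = np.allclose(single_permute, reference)
print("single-permute variant matches einsum:", single_permute_ok)
```

In this sketch the states keep their original layout, so the only transpose is the one applied to the (much smaller) decay chunks.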

vasqu commented 5 days ago

I'm a bit surprised that the operations after that don't fail. Have you tested your fixed version on a forward pass?

HanGuo97 commented 5 days ago

As far as I remember, the following operations won't fail because the reduction was over the number of (source) chunks even though it should be over the number of (target) chunks. During training, these two are the same size.
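The shape coincidence described here can be illustrated with a small sketch. It assumes NumPy and a hypothetical helper, `reduce_shapes`, that is not part of any library; the unequal sizes in the second call are artificial and only serve to make the difference visible.

```python
import numpy as np

def reduce_shapes(z, c, b=1, h=2, p=3, n=4):
    """Return the output shapes of summing bhzcpn over axis 2 and over axis 3."""
    decay_chunk = np.ones((b, h, z, c))        # b h z c
    states_permuted = np.ones((b, h, c, p, n)) # b h c p n
    product = decay_chunk[:, :, :, :, None, None] * states_permuted[:, :, None, ...]
    return product.sum(axis=2).shape, product.sum(axis=3).shape

# Equal chunk axes: both reductions give the same shape, so later ops do not fail.
same = reduce_shapes(z=4, c=4)

# Artificially unequal chunk axes: the shapes diverge and the mismatch would surface.
different = reduce_shapes(z=5, c=4)

print("z == c:", same)
print("z != c:", different)
```

When the two chunk axes have equal length, a reduction over the wrong one is silent at the shape level, which is consistent with the downstream operations not raising an error.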

vasqu commented 5 days ago

It's been a while but yea that makes sense. Thx for clarifying!

vasqu commented 5 days ago

A tad late, but I've verified it myself now based on my test and modifying the respective local ssd minimal:

ArthurZucker commented 7 hours ago

Will have a look asap thanks @vasqu !