[Open] AstroCIEL opened this issue 1 month ago
It is really awesome work for understanding the computing pattern of the mamba2 model. But I notice that you mentioned that the outputs are not numerically equivalent to the reference implementation.

Could you please give a hint on why the outputs are not numerically equivalent? Is there any modification or difference in the model architecture or computing pattern compared with the reference implementation? If not, what causes the different output? Since John Ma's mamba-minimal repo states that its output logits are equivalent to the original Mamba's, I wonder why this implementation leads to non-equivalent output logits. Looking forward to your kind answer.

I would be surprised if mamba-minimal actually produced output that is bit-level equivalent to the reference implementation. In general, floating-point calculations are very sensitive to the tiniest of changes; even different patch versions of PyTorch can give different results due to new optimizations.
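To make the floating-point point concrete, here is a small, self-contained PyTorch sketch (not taken from either repo; the shapes and tolerances are arbitrary choices for illustration). It computes the same quantity with two different association orders and shows that the results agree only up to a tolerance, not bit for bit:

```python
import torch

torch.manual_seed(0)

# Two mathematically identical ways of computing the same tensor.
x  = torch.randn(256, 256)
w1 = torch.randn(256, 256)
w2 = torch.randn(256, 256)

a = (x @ w1) @ w2   # left-to-right association
b = x @ (w1 @ w2)   # right-to-left association

# In exact arithmetic a == b, but float32 rounding depends on the order
# of operations, so the two results typically differ in the last bits.
print(torch.equal(a, b))                            # typically False
print(torch.allclose(a, b, rtol=1e-3, atol=1e-3))   # typically True
print((a - b).abs().max())                          # small but usually nonzero
```

The same effect applies to any reimplementation that fuses, reorders, or batches operations differently from the reference kernels, which is why "equivalent output" claims are usually checked with `torch.allclose` under a tolerance rather than with exact equality.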