alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

VisionMamba implementation #29

Closed AliYoussef97 closed 5 months ago

AliYoussef97 commented 5 months ago

Vision Mamba implementation. The forward function contains both the forward-direction and the backward-direction SSM. However, I did not implement the step function: each direction would return its own cache, while the final output is the combination of the forward and backward directions.
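Roughly, the structure is as follows (a minimal sketch only; the class name and the `nn.Linear` placeholders are illustrative, not the actual code in this PR):

```python
import torch
import torch.nn as nn

class BiDirectionalBlock(nn.Module):
    """Illustrative sketch: the nn.Linear layers stand in for the two
    direction-specific Mamba SSM branches."""

    def __init__(self, d_model: int):
        super().__init__()
        self.ssm_fwd = nn.Linear(d_model, d_model)  # placeholder for the forward-direction SSM
        self.ssm_bwd = nn.Linear(d_model, d_model)  # placeholder for the backward-direction SSM

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        y_fwd = self.ssm_fwd(x)
        # Backward direction: flip along the sequence axis, scan, flip back.
        y_bwd = self.ssm_bwd(x.flip(1)).flip(1)
        # The output combines both directions, so there is no single
        # recurrent cache that a step() function could carry forward.
        return y_fwd + y_bwd
```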

alxndrTL commented 5 months ago

Thank you for the PR! Have you tried it?

AliYoussef97 commented 5 months ago

@alxndrTL I passed dummy tensors through the model to check that no errors are thrown, but I haven't trained it on a specific task, as writing a training script similar to this would take a bit more time. I wouldn't mind making one when I have more time on my hands.

I might make another commit to this PR later today to make the code a little bit cleaner. I will probably commit those changes in the next hour or so. Commits added.

alxndrTL commented 5 months ago

Thanks! I'll merge this now and add comments in the file as well as in the README

AliYoussef97 commented 5 months ago

> Thanks! I'll merge this now and add comments in the file as well as in the README

Great! I’ll try to add a training script and the full trainable model sometime this month, similar to this (adding a cls token, positional embedding, random sequence flips, etc.; this PR is the Vision Mamba block only).

The Vision Mamba block is the “v2” here (it gets called from the Block function in the first link), which is what this PR is based on.

alxndrTL commented 4 months ago

Hello, I'm refactoring your code in vim.py. If I understood you correctly, I can get rid of everything related to the cache, right?

AliYoussef97 commented 4 months ago

That is correct; I did not implement the step function, only the forward function.
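In other words, the block exposes only a full-sequence forward pass. Hypothetical usage, building on the illustrative sketch earlier in the thread (not the repo's actual API):

```python
import torch

block = BiDirectionalBlock(d_model=16)  # the illustrative sketch from above
x = torch.randn(2, 10, 16)              # (batch, seq_len, d_model)
y = block(x)                            # one full-sequence pass; no step()/cache API
print(y.shape)                          # torch.Size([2, 10, 16])
```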