state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Would you post a minimal example of training this? #11

freckletonj opened this issue 11 months ago

freckletonj commented 11 months ago

Amazing work, and I'm inspired by the connections to dynamical systems.

Would you mind showing us a minimal example of training or finetuning this?

RevolGMPHL commented 11 months ago

same problem

albertfgu commented 11 months ago

We released just the core model because it can be dropped in as a replacement for the model in any training/finetuning pipeline, of which there are many. Is there an example application you have in mind?
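
For example, a minimal next-token-prediction loop might look like this. This is a sketch, not an official recipe: the data is random token ids just to show the shapes, the hyperparameters are placeholders, and `mamba_ssm` API details may vary by version.

```python
import torch
import torch.nn.functional as F
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
# Load one of the released checkpoints from the HF hub (130m is the smallest).
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device=device, dtype=torch.bfloat16
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Toy stand-in for a real dataloader: random token ids, just to show the shapes.
# A real run would tokenize text with the GPT-NeoX tokenizer these models use.
dataloader = [torch.randint(0, 50277, (2, 512)) for _ in range(10)]

model.train()
for input_ids in dataloader:
    input_ids = input_ids.to(device)
    logits = model(input_ids).logits  # (batch, seq_len, vocab_size)
    # Next-token prediction: token t+1 is the label for position t.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)).float(),
        input_ids[:, 1:].reshape(-1),
    )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```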

freckletonj commented 11 months ago

Thanks for the reply, and whoa, just any PyTorch training setup will do? I'm just interested in next-token prediction.

Does it get along with, say, the accelerate ecosystem for multi-node/multi-GPU training? I saw transformers in setup.py; how does that work? I thought this architecture wasn't related to it?

I assume optimizations like FlashAttention are no longer relevant?

When you release larger models (fingers-crossed!!!), bitsandbytes will likely become relevant, as well as PEFT and QLoRA, and DeepSpeed.

But then I'm also curious about some training params: learning rate? AdamW? weight decay?
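
For reference, the kind of setup I'd guess at is a common GPT-style recipe, something like the following. These values are placeholders, not confirmed Mamba hyperparameters, and `model` is assumed to be defined as in the training sketch above.

```python
import torch

# Placeholder guess at a GPT-style recipe, not confirmed Mamba hyperparameters.
num_training_steps = 10_000
optimizer = torch.optim.AdamW(
    model.parameters(),   # `model` as in the training sketch above
    lr=3e-4,              # typically scaled down as model size grows
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_training_steps
)
```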

DiscordJim commented 11 months ago

Agreed, even an example with the HuggingFace Trainer would be lovely. I am running into issues using it with the HuggingFace Trainer, and even with causal language modeling in Transformers without the Trainer. Thank you for the incredible work as well; this is amazing.

huseinzol05 commented 11 months ago

As in https://github.com/state-spaces/mamba/issues/6, I tried DeepSpeed ZeRO-3 with the HF Trainer API, and it looks good.

I added:

  1. a cross-entropy loss,
  2. a Transformers config interface,
  3. a Transformers PreTrainedModel interface (a sketch of this kind of wrapper is after the results below).

The results:

  1. tested saving with safetensors,
  2. loading existing checkpoints to continue pretraining works,
  3. with 80GB of VRAM, the maximum batch size is 8 at 4k context length; one step took ~300ms.
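
For readers who want the shape of it, here is a rough sketch of that kind of wrapper. The class names are illustrative, not the exact code from the linked issue, and the `MambaLMHeadModel` constructor differs across `mamba_ssm` versions (recent ones take a `MambaConfig` object).

```python
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

class MambaHFConfig(PretrainedConfig):
    """Illustrative Transformers-style config for the core Mamba model."""
    model_type = "mamba"

    def __init__(self, d_model=768, n_layer=24, vocab_size=50277, **kwargs):
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        super().__init__(**kwargs)

class MambaForCausalLM(PreTrainedModel):
    """Illustrative PreTrainedModel wrapper with a cross-entropy loss."""
    config_class = MambaHFConfig

    def __init__(self, config):
        super().__init__(config)
        # Recent mamba_ssm versions take a MambaConfig; older ones took
        # d_model/n_layer/vocab_size directly.
        self.backbone = MambaLMHeadModel(
            MambaConfig(
                d_model=config.d_model,
                n_layer=config.n_layer,
                vocab_size=config.vocab_size,
            )
        )

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.backbone(input_ids).logits
        loss = None
        if labels is not None:
            # Standard shifted cross-entropy for next-token prediction.
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)).float(),
                labels[:, 1:].reshape(-1),
            )
        # HF Trainer reads "loss" from a dict-style output.
        return {"loss": loss, "logits": logits}
```

With a wrapper like this, the usual `Trainer`/`TrainingArguments` flow applies, including passing a ZeRO-3 JSON config via `TrainingArguments(deepspeed=...)`.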

DiscordJim commented 11 months ago

Just saw your post; great work, and I tested on my end with similar success.

freckletonj commented 11 months ago

Geez, open source is fast. Here's a chat-ified version with a simple training example: https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py