alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

training functions? #2

win10ogod closed this issue 8 months ago

win10ogod commented 9 months ago

Hello, does this project have training functions similar to llama2.c?

alxndrTL commented 9 months ago

Hello, no, for now you have to train it "manually": there isn't a training script (you can see an example in examples/example_e2e_training.ipynb). I might add a training script like llama2.c in the near future.
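
For reference, here is a minimal sketch of what training mamba.py "manually" can look like, assuming the `MambaLM`/`MambaLMConfig` interface from mamba_lm.py; the config fields, hyperparameters, and the random-token `get_batch` are placeholders for illustration, not the notebook's code.

```python
import torch
import torch.nn.functional as F
from mamba_lm import MambaLM, MambaLMConfig  # interface assumed from mamba_lm.py

device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size = 4096

# Hypothetical config values, purely for illustration.
config = MambaLMConfig(d_model=256, n_layers=4, vocab_size=vocab_size)
model = MambaLM(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

def get_batch(batch_size=8, seq_len=128):
    # Stand-in loader: random tokens; swap in your real dataset.
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len + 1), device=device)
    return tokens[:, :-1], tokens[:, 1:]  # inputs and next-token targets

for step in range(1000):
    x, y = get_batch()
    logits = model(x)  # (B, L, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```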

HaileyStorm commented 8 months ago

I created some training scripts. They're centered around this ChessGPT repo: https://github.com/adamkarvonen/nanoGPT, and they're kind of a mess, so I don't want to submit them as a PR, but they should make for a good start.

train.py works with .bin files and a fixed sequence length set at training time.
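
For context, this is roughly the nanoGPT-style batching pattern such a script uses: tokens stored as a flat uint16 .bin file, memory-mapped, with random fixed-length windows sampled each step. The file name, dtype, and sizes below are assumptions for illustration, not the exact train.py code.

```python
import numpy as np
import torch

# Memory-map the flat token stream; file name and dtype are assumptions.
data = np.memmap("train.bin", dtype=np.uint16, mode="r")
seq_len, batch_size = 256, 32  # fixed training-time sequence length

def get_batch():
    # Sample random windows of length seq_len; targets are the inputs shifted by one.
    ix = torch.randint(len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + seq_len].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + seq_len].astype(np.int64)) for i in ix])
    return x, y
```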

train_bygame.py works with parquet files. It assumes the data is one sequence per row, sorted by sequence length, and then split into files. It randomly reads from these files, and the sequence length for a training iteration is the max length of the sequences in the batch (or max_seq_len, which is there to cap VRAM use). I did this instead of just using df.sample on the complete dataset because rapidly changing the sequence length was causing crashes; this way you get the speed benefit of only using the sequence length you need, without the instability. train.zip
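
A rough sketch of the bucketed-parquet batching described above (not the actual train_bygame.py; the file names, the "tokens" column, and the pad id are assumptions): because sequences are pre-sorted by length and split across files, each file acts as a length bucket, one file is picked at random per step, and the batch is padded to its longest sequence, capped at max_seq_len.

```python
import random
import pandas as pd
import torch

# Each parquet file holds sequences of similar length (one sequence per row).
files = ["train_00.parquet", "train_01.parquet", "train_02.parquet"]
max_seq_len = 1024  # cap to keep VRAM use bounded
pad_id = 0

def get_batch(batch_size=16):
    df = pd.read_parquet(random.choice(files))        # one random length bucket
    rows = df["tokens"].sample(batch_size).tolist()   # one token sequence per row
    seq_len = min(max(len(r) for r in rows), max_seq_len)
    batch = torch.full((batch_size, seq_len), pad_id, dtype=torch.long)
    for i, r in enumerate(rows):
        r = list(r)[:seq_len]
        batch[i, :len(r)] = torch.tensor(r, dtype=torch.long)
    return batch[:, :-1], batch[:, 1:]                # inputs / next-token targets
```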

alxndrTL commented 8 months ago

Thanks, I will take a look!

alxndrTL commented 8 months ago

Hello, if anyone is interested in a full-fledged training script, you can check out the othello_mamba repo. It features a complete training script (similar to llama2.c) that you can easily adapt to your needs. It is compatible with mamba.py (it doesn't use mamba_lm.py, but a more general lm.py that also works for Transformers).
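
For readers who just want the gist of that design: a model-agnostic lm.py keeps the embedding and LM head in a thin wrapper around any sequence-mixing core, so the same wrapper works with a Mamba stack or a Transformer. The sketch below is hypothetical (class and argument names are made up), not the actual lm.py from othello_mamba.

```python
import torch
import torch.nn as nn

class LM(nn.Module):
    # Hypothetical wrapper: embedding + LM head around any (B, L, d) -> (B, L, d) core.
    def __init__(self, core: nn.Module, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.core = core                      # e.g. a stack of Mamba blocks or Transformer blocks
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embedding(tokens)            # (B, L) int64 -> (B, L, d_model)
        x = self.core(x)                      # sequence mixing, architecture-agnostic
        return self.lm_head(x)                # (B, L, vocab_size) logits
```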