huseinzol05 opened 7 months ago
We don't use DeepSpeed, just PyTorch AMP (bf16) to train models. Can you try that?
Looks good with AMP bf16: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-bf16.ipynb. I also saw there is a 2.8B checkpoint. What batch size did it use, and does it use the 2k context length from the paper? Does the 2.8B model fit in a single A100 80GB?
The 2.8B uses a total batch size of 1M tokens (following the GPT-3 paper) and seqlen=2k. Activation memory should be around the same as an optimized transformer (e.g. with FlashAttention), so it should fit in a single A100 80GB (you might need to adjust gradient accumulation to fit). For models around 1B-3B on 8xA100s, sharding the optimizer states (e.g. with the PyTorch distributed optimizer, equivalent to ZeRO stage 1) will help reduce the amount of memory needed.
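A rough sketch of that optimizer-state sharding with `torch.distributed.optim.ZeroRedundancyOptimizer`; `model` and the hyperparameters here are placeholders, not the actual training setup:

```python
# Sketch: shard AdamW optimizer states across data-parallel ranks
# (ZeRO stage-1 equivalent). `model` is a placeholder.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
model = DDP(model.cuda())

# Each rank only keeps its shard of the AdamW states (exp_avg, exp_avg_sq),
# cutting optimizer memory roughly by the number of ranks.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=3e-4,
    weight_decay=0.1,
)
```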
Thanks!
I tried ZeRO-2 with FP32; sometimes it overflowed and sometimes it didn't. After that I tried a 1e-6 learning rate with gradient-norm clipping at 0.1 and it no longer overflowed, but a 0.1 gradient norm is too low for pretraining.
After properly wrapping the model with transformers' PreTrainedModel and using the 1.4B config, surprisingly there is no more overflow: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-1.4b-trainer-deepspeed3-bf16.ipynb
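A rough sketch of what that kind of wrapping could look like; this is an illustrative reconstruction, not the actual mesolitica code, and the MambaLMHeadModel API may differ between mamba_ssm versions:

```python
# Illustrative sketch only: wrap mamba_ssm's MambaLMHeadModel in a
# transformers PreTrainedModel shell. Class names and the loss computation
# here are assumptions, not the original notebook's code.
import torch
from transformers import PretrainedConfig, PreTrainedModel
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

class MambaWrapperConfig(PretrainedConfig):
    model_type = "mamba-wrapper"

class MambaWrapperForCausalLM(PreTrainedModel):
    config_class = MambaWrapperConfig

    def __init__(self, config, pretrained_name="state-spaces/mamba-1.4b"):
        super().__init__(config)
        # Load the published 1.4B checkpoint; replace with a freshly
        # constructed MambaLMHeadModel when pretraining from scratch.
        self.backbone = MambaLMHeadModel.from_pretrained(pretrained_name)

    def forward(self, input_ids, labels=None):
        logits = self.backbone(input_ids).logits
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from tokens <= t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return {"loss": loss, "logits": logits}
```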
Thanks for your explorations! I added a little warning to the README about the dtype, but it's quite useful for people to post their observations here. We've actually never seen these types of instabilities during training; as Tri said, we just use native PyTorch AMP.
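For reference, a minimal sketch of a native PyTorch AMP (bf16) loop; `model`, `dataloader`, and the loss computation are placeholders, not the authors' actual training code:

```python
# Sketch: bf16 autocast training loop. Parameters stay in fp32;
# forward/backward compute runs in bfloat16. Unlike fp16, bf16 needs no GradScaler.
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

for input_ids, labels in dataloader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(input_ids.cuda())
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.cuda().view(-1)
        )
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
```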
I am using PyTorch's FSDP with bf16 for training. It looks like I encountered a similar issue with NaN loss.
What is your auto_wrap_policy, if you don't mind sharing it? @binxuan
This worked for me:
```python
from functools import partial

from mamba_ssm.modules.mamba_simple import Block
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Treat each Mamba Block as its own FSDP wrapping unit.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)
```
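A minimal sketch of how this policy might then be passed to FSDP together with bf16 mixed precision; `model` is a placeholder for the Mamba LM:

```python
# Sketch: wrap the model in FSDP using the policy above, with bf16 compute.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model.cuda(),
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=bf16_policy,
)
```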
I was also able to pretrain a 200M Mamba LM on 12B tokens yesterday in bf16 without issue. I am testing FSDP + bf16 right now and everything works as expected as well.
Hi, I am using the pytorch_lightning trainer with bf16-mixed precision (i.e., model params in fp32 and compute in bf16; pytorch_lightning also uses PyTorch AMP under the hood) to train Mamba on DNA data, but the training is still quite unstable.
Weirdly, I trained with the same code and the same model (at a different size) using fp16-mixed precision on NLP data (WikiText-103), and such instability did not occur. Any suggestions?
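For context, a minimal sketch of the bf16-mixed Lightning setup described above; `MyLightningModule` and `train_loader` are placeholders, and the exact precision string depends on the Lightning version:

```python
# Hypothetical sketch of a pytorch_lightning bf16-mixed run; module and
# dataloader names are placeholders, not the poster's actual code.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",  # fp32 master weights, bf16 autocast compute (PL >= 2.0)
    gradient_clip_val=1.0,   # clipping can help with loss spikes
    max_steps=100_000,
)
trainer.fit(MyLightningModule(), train_loader)
```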
I have also observed similar instabilities using the PL trainer. In my case I am creating custom embeddings, which I concatenate with the text embeddings before feeding them to the mixer model. I have tried lowering the learning rate and disabling AMP, but my loss still goes to NaN.
@gpantaz and @lwang2070, were you able to fix the issue with bf16 or fp16 training in PL? I also see training issues for my model.
No sadly :/
If bfloat16 is not working, DeepSpeed is not going to work either.