state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

bfloat16 overflow during training session #6

Open huseinzol05 opened 7 months ago

huseinzol05 commented 7 months ago
  1. I tried a vanilla PyTorch training loop in bfloat16, and the loss overflowed: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-bf16.ipynb
  2. So I tried a vanilla PyTorch training loop in fp32, and the loss is fine: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-fp32.ipynb
  3. I thought it might be because of missing gradient clipping etc., so I tried the HuggingFace Trainer with DeepSpeed, https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-deepspeed-bf16.ipynb, and the loss overflowed.
  4. So I removed DeepSpeed and used fp32 in the HuggingFace Trainer, https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-fp32.ipynb, and the loss is fine.

If bfloat16 does not work, DeepSpeed is not going to work either.

tridao commented 7 months ago

We don't use DeepSpeed, just PyTorch AMP (bf16) to train models. Can you try that?

tridao commented 7 months ago

Model parameters should be in fp32, just as in the PyTorch AMP docs.
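A minimal pure-Python sketch of why the parameters themselves should stay in fp32: it emulates bfloat16 rounding by keeping the top 16 bits of a float32 (the helper `to_bf16` is hypothetical, not a PyTorch API), then applies 1,000 small updates to a weight of 1.0. The update size (1e-3) is an assumption chosen for illustration, and Python's float64 stands in for fp32 master weights.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-half-to-even).

    Hypothetical helper: bfloat16 is the top 16 bits of an IEEE-754 float32
    (1 sign, 8 exponent, 7 mantissa bits). NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias to the 16 low bits that are about to be dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def apply_updates(n_steps: int, update: float, bf16_params: bool) -> float:
    """Add `update` to a weight of 1.0, n_steps times.

    If bf16_params is True, the weight is stored in bf16 after every step;
    otherwise it stays a Python float (float64), standing in for fp32.
    """
    weight = 1.0
    for _ in range(n_steps):
        weight = weight + update
        if bf16_params:
            weight = to_bf16(weight)
    return weight


print(apply_updates(1000, 1e-3, bf16_params=True))   # 1.0: every update is rounded away
print(apply_updates(1000, 1e-3, bf16_params=False))  # about 2.0
```

With bf16 parameters, each 1e-3 update is smaller than half the bf16 spacing at 1.0 (2^-8, about 0.0039), so the weight never moves, while the high-precision copy reaches about 2.0. Under AMP, autocast runs eligible ops in bf16 while the optimizer updates the fp32 weights, which avoids this.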

huseinzol05 commented 7 months ago

Looks good with AMP bf16: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-bf16.ipynb. I also saw there is a 2.8B checkpoint. What batch size was used, and was it trained with a 2k context length as in the paper? Does the 2.8B model fit on a single A100 80GB?

tridao commented 7 months ago

The 2.8B model uses a total batch size of 1M tokens (following the GPT-3 paper) and seqlen=2k. Activation memory should be about the same as for an optimized Transformer (e.g. with FlashAttention), so it should fit on a single A100 80GB (you might need to adjust gradient accumulation to make it fit). For models around 1B-3B on 8xA100s, sharding the optimizer states (e.g. with the PyTorch distributed optimizer, equivalent to ZeRO stage 1) will help reduce the amount of memory needed.
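A back-of-the-envelope sketch of the two numbers in this answer. It assumes "1M tokens" means 2^20 = 1,048,576 tokens, a hypothetical per-GPU micro-batch of 8 sequences, and the common rough count for AdamW under AMP: fp32 weights (4 bytes), fp32 gradients (4 bytes) and two fp32 Adam moments (8 bytes) per parameter, ignoring activations, buffers and allocator overhead. The function names are illustrative only.

```python
def grad_accum_steps(global_batch_tokens: int, seq_len: int,
                     micro_batch_size: int, num_gpus: int) -> int:
    """Accumulation steps needed so that micro-batches add up to the global batch."""
    sequences = global_batch_tokens // seq_len
    per_pass = micro_batch_size * num_gpus
    if sequences % per_pass:
        raise ValueError("global batch must be divisible by micro_batch_size * num_gpus")
    return sequences // per_pass


def train_state_gb(n_params: float, num_gpus: int = 1,
                   shard_optimizer: bool = False) -> float:
    """Rough per-GPU memory (GB = 1e9 bytes) for weights, grads and AdamW state.

    shard_optimizer=True splits only the two Adam moments across GPUs,
    in the spirit of ZeRO stage 1; weights and grads stay replicated.
    """
    weights = 4 * n_params       # fp32 parameters
    grads = 4 * n_params         # fp32 gradients
    adam_moments = 8 * n_params  # exp_avg + exp_avg_sq in fp32
    if shard_optimizer:
        adam_moments /= num_gpus
    return (weights + grads + adam_moments) / 1e9


GLOBAL_BATCH = 2 ** 20  # "1M tokens", assumed to mean 2^20
SEQ_LEN = 2048

print(GLOBAL_BATCH // SEQ_LEN)                                                  # 512 sequences
print(grad_accum_steps(GLOBAL_BATCH, SEQ_LEN, micro_batch_size=8, num_gpus=1))  # 64
print(grad_accum_steps(GLOBAL_BATCH, SEQ_LEN, micro_batch_size=8, num_gpus=8))  # 8
print(train_state_gb(2.8e9))                                                    # ~44.8
print(train_state_gb(2.8e9, num_gpus=8, shard_optimizer=True))                  # ~25.2
```

Under these assumptions the training state alone takes roughly 45 GB on one GPU, leaving the rest of the 80 GB for activations, which is why the micro-batch and gradient accumulation may need tuning; sharding the Adam moments over 8 GPUs lowers the estimate to about 25 GB per GPU.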

huseinzol05 commented 7 months ago

Thanks!

huseinzol05 commented 7 months ago

I tried ZeRO-2 in FP32: sometimes the loss overflows and sometimes it does not. After that I tried a learning rate of 1e-6 and a max gradient norm of 0.1; the loss no longer overflows, but a 0.1 gradient norm is too low for pretraining.
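To make the "0.1 gradient norm" setting concrete, here is a pure-Python sketch of clipping by global norm, following roughly the same rule as PyTorch's torch.nn.utils.clip_grad_norm_ (all gradients are multiplied by max_norm / (total_norm + eps) when that factor is below 1). Gradients are plain lists of floats here, and clip_by_global_norm is a hypothetical name.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is <= max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


grads = [[3.0, 4.0], [0.0]]  # toy gradients with global norm 5.0
for max_norm in (1.0, 0.1):
    clipped, norm = clip_by_global_norm(grads, max_norm)
    print(max_norm, norm, clipped)
```

With max_norm=0.1, any step whose gradient norm exceeds 0.1 is shrunk to norm 0.1, so together with a 1e-6 learning rate the effective update is very small, which fits the observation that this stops the overflow but is too restrictive for pretraining.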

huseinzol05 commented 7 months ago

After properly wrapping the model with transformers' PreTrainedModel and switching to 1.4B, surprisingly there is no more overflow: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-1.4b-trainer-deepspeed3-bf16.ipynb

  1. Tested saving with safetensors.
  2. Loaded existing checkpoints to continue pretraining.
  3. With 80GB VRAM, the maximum batch size is 8 at 4k context length; one step took ~300ms.

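The numbers in item 3 imply a rough throughput. A small sketch of the arithmetic, assuming "4k" means 4,096 tokens, a single GPU, and (carried over from the earlier answer about the 2.8B model) a global batch of 2^20 tokens; these are illustrative assumptions, not measurements from the notebook, and the helper names are hypothetical.

```python
import math


def tokens_per_second(batch_size: int, seq_len: int, step_seconds: float) -> float:
    """Training throughput for one device."""
    return batch_size * seq_len / step_seconds


def accum_steps_for(global_batch_tokens: int, batch_size: int, seq_len: int) -> int:
    """Micro-steps needed to accumulate at least global_batch_tokens."""
    return math.ceil(global_batch_tokens / (batch_size * seq_len))


BATCH, SEQ_LEN, STEP_S = 8, 4096, 0.3
print(BATCH * SEQ_LEN)                                   # 32768 tokens per step
print(round(tokens_per_second(BATCH, SEQ_LEN, STEP_S)))  # 109227 tokens/s
steps = accum_steps_for(2 ** 20, BATCH, SEQ_LEN)
print(steps, steps * STEP_S)                             # 32 micro-steps, ~9.6 s per optimizer step
```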
albertfgu commented 7 months ago

Thanks for your explorations! I added a little warning to the README about the dtype, but it's quite useful for people to post their observations here. We've actually never seen these types of instabilities during training; as Tri said, we just use native PyTorch AMP.

binxuan commented 7 months ago

I am using PyTorch's FSDP with bf16 for training. It looks like I encountered a similar issue with NaN loss.

geronimi73 commented 7 months ago

What is your auto_wrap_policy, if you don't mind sharing it, @binxuan?

btrude commented 6 months ago

This worked for me:

from functools import partial

from mamba_ssm.modules.mamba_simple import Block
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap each Mamba Block as its own FSDP unit.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)

I was also able to pretrain a 200M Mamba LM on 12B tokens yesterday in bf16 without issue. I'm testing FSDP + bf16 right now, and everything works as expected there as well.

lwang2070 commented 6 months ago

Hi, I am using the pytorch_lightning Trainer with bf16-mixed precision (model params in fp32, compute in bf16; pytorch_lightning also uses PyTorch AMP under the hood) to train Mamba on DNA data. However, training is still quite unstable.

[Figure: loss on hg38]

Weirdly, I trained with the same code and the same model (at a different size) in fp16 mixed precision on NLP data (wiki103), and this instability did not occur. Any suggestions?
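One difference between the two settings that may be worth keeping in mind (a possible factor, not a diagnosis): fp16 keeps 10 mantissa bits but only 5 exponent bits, whereas bf16 keeps fp32's 8 exponent bits with only 7 mantissa bits. The sketch below derives each format's largest finite value, smallest normal value and machine epsilon from its bit layout; format_limits is a hypothetical helper, not a library function.

```python
def format_limits(exp_bits: int, mant_bits: int) -> dict:
    """Largest finite value, smallest normal value and machine epsilon
    of an IEEE-754-style binary format with the given field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the largest usable
    # unbiased exponent equals the bias.
    max_finite = (2 - 2.0 ** -mant_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -mant_bits  # spacing of values just above 1.0
    return {"max": max_finite, "min_normal": min_normal, "eps": epsilon}


for name, (e, m) in {"fp16": (5, 10), "bf16": (8, 7), "fp32": (8, 23)}.items():
    print(name, format_limits(e, m))
```

fp16 tops out at 65504, so fp16 training normally relies on loss scaling, but its epsilon (2^-10) is 8 times finer than bf16's (2^-7). bf16 shares fp32's range and rarely overflows, yet rounds small differences away more easily.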

gpantaz commented 6 months ago

I have also observed similar instabilities using the PL trainer. In my case I am creating custom embeddings, which I concatenate with the text embeddings from Mamba before feeding them to the mixer model. I have tried lowering the learning rate and disabling AMP, but my loss still goes to NaN.
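When the loss goes to NaN, it can help to find the first step where a logged value stops being finite and then inspect the batch and learning rate at that point. A minimal sketch, assuming the per-step losses (or gradient norms) have already been collected into a Python list; first_nonfinite_step is a hypothetical helper and the example log is made up.

```python
import math
from typing import Optional, Sequence


def first_nonfinite_step(values: Sequence[float]) -> Optional[int]:
    """Index of the first NaN or +/-inf entry, or None if all values are finite."""
    for step, value in enumerate(values):
        if not math.isfinite(value):
            return step
    return None


losses = [2.31, 2.05, 1.98, float("inf"), float("nan")]  # made-up example log
print(first_nonfinite_step(losses))  # 3
```

Running it on both the loss log and the gradient-norm log can show which one diverges first.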

apoorv2904 commented 5 months ago

@gpantaz and @lwang2070, were you able to fix the bf16 or fp16 training issue with PL? I also see training issues with my model.

gpantaz commented 5 months ago

No sadly :/