mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

MPT training with ALiBi and Flash Attention 2 #1289

Open rickgit16 opened 2 weeks ago

rickgit16 commented 2 weeks ago

I am trying to pretrain an MPT model with llm-foundry using ALiBi with Flash Attention. During pretraining, I see the warning below:

WARNING: composer.algorithms.alibi.alibi: ALiBi had no effect on the model! Support for ALiBi surgery is currently limited to the following classes: 
    transformers.models.bert.modeling_bert.BertEmbeddings
    transformers.models.bert.modeling_bert.BertSelfAttention
    transformers.models.gpt2.modeling_gpt2.GPT2Attention
    transformers.models.gpt2.modeling_gpt2.GPT2Model
    transformers.models.roberta.modeling_roberta.RobertaEmbeddings
    transformers.models.roberta.modeling_roberta.RobertaSelfAttention

I followed PR #820 (ALiBi with FA2) for the setup, and I use the following in my pretraining YAML file:

model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 1024
  n_heads: 16
  n_layers: 24
  expansion_ratio: 4
  max_seq_len: 2048
  vocab_size: 50368
  loss_fn: torch_crossentropy
  attn_config:
    attn_impl: flash

algorithms:
  alibi:
    max_sequence_length: 2048

To confirm that ALiBi was not applied, I converted the Composer checkpoint to a Hugging Face checkpoint using scripts/inference/convert_composer_to_hf.py. In the resulting config.json, the attn_config.alibi flag is set to false.

Any guidance on how to use ALiBi with Flash Attention 2 would be immensely helpful.

dakinggg commented 1 week ago

Hi, to turn on ALiBi in MPT, don't use the Composer algorithm approach; instead, specify it directly in the model architecture config. Here is an example: https://github.com/mosaicml/llm-foundry/blob/c23be4ab9e146ff1064758a83fbe57c7d7a8e2ba/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support
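A minimal sketch of what that could look like, adapting the YAML from the original post: ALiBi is enabled through the model's attn_config (the alibi key, which is the same attn_config.alibi flag seen in the converted config.json), and the algorithms: alibi block is dropped, since the warning above shows Composer's ALiBi surgery does not target MPT classes. This assumes the attn_config.alibi option described in the linked tutorial; check that tutorial for any related positional-embedding settings your llm-foundry version expects.

```yaml
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 1024
  n_heads: 16
  n_layers: 24
  expansion_ratio: 4
  max_seq_len: 2048
  vocab_size: 50368
  loss_fn: torch_crossentropy
  attn_config:
    attn_impl: flash   # Flash Attention backend
    alibi: true        # ALiBi handled inside MPT's own attention, not via Composer surgery

# Note: no `algorithms: alibi:` section here; the Composer algorithm
# only patches the BERT/GPT-2/RoBERTa classes listed in the warning.
```

After training with this config, converting the checkpoint with scripts/inference/convert_composer_to_hf.py should show attn_config.alibi as true in config.json.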

rickgit16 commented 1 week ago

Hi @dakinggg, thank you for the reference. Do we still need to follow PR #820 for the setup?

dakinggg commented 1 week ago

Which part of that PR are you referring to? Just installing with pip install .[gpu] and specifying attn_impl: flash should work fine.