RulinShao / LightSeq

Official repository for LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers

DistFlashAttention: Distributed Memory-efficient Attention for Long-context LLMs Training

Official repository for the COLM 2024 paper "DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training". DISTFLASHATTN, also named LightSeq, trains up to 2x faster and supports 2-8x longer sequences than Megatron-LM on 16 80GB A100 GPUs.

Paper: https://arxiv.org/pdf/2310.03294.pdf

DistFlashAttention Implementation

Usage

We provide an example of using LightSeq for training. To run LightSeq on a single node with 8 GPUs, replace data_path with the path to your own dataset and run

python -m torch.distributed.run --nproc_per_node=8 \
        lightseq/train_lightseq_no_trainer.py \
        --model_name_or_path Llama-2-7b-chat-hf \
        --data_path <your_dataset>.pkl \
        --bf16 \
        --output_dir outputs \
        --num_train_epochs 3 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 4 \
        --gradient_accumulation_steps 1 \
        --evaluation_strategy no \
        --save_strategy steps \
        --save_steps 1000 \
        --save_total_limit 1 \
        --learning_rate 2e-5 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --fsdp "full_shard auto_wrap" \
        --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
        --tf32 True \
        --model_max_length 16384 \
        --gradient_checkpointing True \
        --lazy_preprocess True
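
To get a feel for what sequence-level parallelism means for the settings above, the sketch below splits one 16384-token sequence across 8 workers. It is only an illustration: the helper name split_sequence is hypothetical, and the even, contiguous split is an assumption, since the repository may partition or balance the sequence differently.

```python
def split_sequence(seq_len: int, world_size: int) -> list[tuple[int, int]]:
    """Return one (start, end) token range per rank for an even, contiguous split.

    Hypothetical helper for illustration; not part of the LightSeq code base.
    """
    if seq_len % world_size != 0:
        raise ValueError("seq_len must be divisible by world_size")
    chunk = seq_len // world_size
    return [(rank * chunk, (rank + 1) * chunk) for rank in range(world_size)]


# Values taken from the command above: --model_max_length 16384, --nproc_per_node=8
ranges = split_sequence(seq_len=16384, world_size=8)
for rank, (start, end) in enumerate(ranges):
    print(f"rank {rank}: tokens [{start}, {end})")
```

Under this assumption each worker holds 2048 tokens of the sequence, which is why the maximum trainable context can grow with the number of GPUs.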

Rematerialization-aware gradient checkpointing

Rematerialization-aware gradient checkpointing can reduce your training time with a one-line change. We release it as a Python package, fastckpt, which you can install with

pip install fastckpt

To replace HF checkpointing with FastCkpt and HF LlamaAttention with FlashAttention, run

# import fastckpt before importing transformers
from fastckpt.llama_flash_attn_ckpt_monkey_patch import replace_hf_ckpt_with_fast_ckpt
replace_hf_ckpt_with_fast_ckpt()

# import transformers and other packages
import transformers
...

Alternatively, if you only want to replace the attention module with FlashAttention, simply run

# import fastckpt before importing transformers
from fastckpt.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# import transformers and other packages
import transformers
...
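
Both helpers above come from modules named as monkey patches. As a generic illustration of that mechanism (not fastckpt's actual implementation), the toy sketch below swaps a method on a class at runtime. ToyAttention, fast_forward, and replace_toy_attn_with_fast_attn are hypothetical names made up for this example.

```python
class ToyAttention:
    """Stand-in for a library class whose forward pass we want to replace."""

    def forward(self, x):
        return f"slow({x})"


def fast_forward(self, x):
    """Replacement implementation with the same signature as ToyAttention.forward."""
    return f"fast({x})"


def replace_toy_attn_with_fast_attn():
    # Rebind the method on the class itself, so every instance picks it up.
    ToyAttention.forward = fast_forward


before = ToyAttention().forward("q")
replace_toy_attn_with_fast_attn()
after = ToyAttention().forward("q")
print(before, "->", after)
```

Because the patch rebinds an attribute on the class, it has to run before the model code that should see the new behavior, which is consistent with the instruction above to import fastckpt first.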

If you find this repo useful, please cite

@article{li2023lightseq,
  title={LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers},
  author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
  journal={arXiv preprint arXiv:2310.03294},
  year={2023}
}