OpenGVLab / LAMM

[NeurIPS 2023 Datasets and Benchmarks Track] LAMM: Multi-Modal Large Language Models and Applications as AI Agents
https://openlamm.github.io/

add flash attention support in training to save memory and speed up #35

Closed lighten001 closed 11 months ago

lighten001 commented 1 year ago
  1. add src/model/flash_attn_patch.py (a sketch of the patch follows this list)
  2. add a "--use_flash_attn" arg in train.py that swaps in the patched LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask
  3. when using flash attention, LLaMA's use_cache should be False
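The PR itself does not inline the patch, but the description matches the common monkey-patch pattern for LLaMA training. Below is a minimal sketch of what such a src/model/flash_attn_patch.py could look like, assuming the flash-attn v2 API (`flash_attn_varlen_qkvpacked_func`) and the transformers LLaMA implementation current at the time of this PR; function names like `_flash_attn_forward` and `replace_llama_attn_with_flash_attn` are illustrative, not taken from the PR. For brevity the sketch also assumes unpadded batches (every sequence has length `q_len`); a full patch would unpad with `flash_attn.bert_padding.unpad_input`.

```python
# flash_attn_patch.py -- illustrative sketch, not the PR's actual file
from typing import Optional, Tuple

import torch
import transformers
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def _flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Drop-in replacement for LlamaAttention.forward (training only, no KV cache)."""
    bsz, q_len, _ = hidden_states.size()
    assert past_key_value is None and not use_cache, "KV cache is not supported by this patch"

    # q/k/v projections -> (bsz, num_heads, q_len, head_dim), same as the original forward
    q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    # rotary position embeddings, unchanged from the stock implementation
    cos, sin = self.rotary_emb(v, seq_len=q_len)
    q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

    # pack into (total_tokens, 3, num_heads, head_dim); assumes no padding in the batch
    qkv = torch.stack([q, k, v], dim=2).transpose(1, 3)          # (bsz, q_len, 3, nh, hd)
    qkv = rearrange(qkv, "b s three h d -> (b s) three h d")
    cu_seqlens = torch.arange(
        0, (bsz + 1) * q_len, q_len, dtype=torch.int32, device=qkv.device
    )

    # fused causal attention; the causal mask is built inside the kernel
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, q_len, 0.0, causal=True)
    out = rearrange(out, "(b s) h d -> b s (h d)", b=bsz)
    return self.o_proj(out), None, None


def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # FlashAttention handles causal masking itself, so skip building the large
    # additive 4D mask and just pass the 2D key-padding mask through.
    return attention_mask


def replace_llama_attn_with_flash_attn():
    llama = transformers.models.llama.modeling_llama
    llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    llama.LlamaAttention.forward = _flash_attn_forward
```

With a patch like this, train.py would call `replace_llama_attn_with_flash_attn()` before the LLaMA model is constructed whenever `--use_flash_attn` is passed. Item 3 then follows naturally: the patched forward never returns `past_key_value`, so the model's `use_cache` must be set to False during training.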