TemporaryLoRA / Temp-LoRA

88 stars 7 forks source link

尝试在RTX6000 48G显存机器上运行,出现OOM问题。运行这个项目需要多少显存? #2

Closed SlothRan closed 6 months ago

SlothRan commented 6 months ago

在运行“bash scripts/llama2.sh ”显式如下: image 配置信息如下: image image image

TemporaryLoRA commented 6 months ago
刚重新做了测试,结果如下: script enable gradient checkpointing mem
scripts/llama2.sh True 22GB
scripts/llama2.sh False 56GB
scripts/example.sh True 24GB (step = 2)
scripts/example.sh False 36GB (step = 2)

实验环境:A800, cuda11.8,mem由nvidia-smi命令得到。 实验配置为本项目提供的官方配置。

注意,"model.generate"过程中,transformers会自动进行一些显存的回收、释放,因此显存会有较大的波动。

启动gradient_checkpointing的方法,以scripts/llama2.sh为例,将最后的"--gradient_checkpointing" 设为 "true' 即可。

ACCELERATE_CONFIG=""
SAVE_DIR=""

MODEL_NAME="togethercomputer/LLaMA-2-7B-32K"

mkdir -p $SAVE_DIR
accelerate launch --config_file $ACCELERATE_CONFIG trainer/acc_pg19_trainer.py --model_name $MODEL_NAME \
  ……
  --gradient_checkpointing "true" 

麻烦您看下开启"gradient_checkpointing"后的结果。