Closed LeonG7 closed 8 months ago
+1
使用 zero3 试试
使用 zero3 试试
zero3不报错,但是一直卡着不动
我用8*A800 80G 跑DPO lora训练qwen-7b-chat都把显存吃满了,RLHF lora都才只用了一半左右 很奇怪(两个都用的是zero2)
之前跑的是14b的模型,估计是模型放不下,换成7b的模型全参DPO能跑了
我的是tinyllama 1.2B的,训练用的DPO lora,A100-80G单卡,SFT和DPO 训练,基本按照readme操作的,数据用的readme上提供的,通过nvidia-smi 发现没有吃显卡内存,怎么回事儿? 脚本如下: CUDA_VISIBLE_DEVICES=7 python3 src/train_bash.py \ --stage dpo \ --do_train \ --model_name_or_path tinyllama_base \ --adapter_name_or_path ./output/tinyllama_sft/checkpoint-1500 \ --create_new_adapter \ --dataset comparison_gpt4_zh \ --template llama2_zh \ --finetuning_type lora \ --lora_target q_proj,v_proj \ --output_dir ./output/tinyllama_dpo \ --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ --save_steps 500 \ --learning_rate 1e-5 \ --num_train_epochs 1.0 \ --plot_loss \ --overwrite_output_dir true \ --bf16
Reminder
Reproduction
ds_config_zero2.json
Expected behavior
No response
System Info
transformers 4.36.2 transformers-stream-generator 0.0.4 pytorch-quantization 2.1.2 torch 2.2.0 torch-tensorrt 1.4.0.dev0 torchdata 0.7.1 torchtext 0.17.0 torchvision 0.17.0 python 3.8.16 (3.10版本也试过,都会报错)
Others
No response