I recommend using 4096 RedPajama samples with a 2048 context length for 2 epochs. This is because 2-bit quantization leads to significant information loss and therefore requires more calibration data to avoid overfitting.
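For intuition, here is a minimal sketch of asymmetric, group-wise 2-bit weight quantization, i.e. the w2a16g128 setting trained by the command below. This is an illustration under my own assumptions, not the repository's implementation: each group of 128 input-channel weights gets its own scale and zero-point over only four representable levels, which is where the large information loss comes from.

import torch

def fake_quant_w2_g128(weight: torch.Tensor, group_size: int = 128, bits: int = 2):
    # Simulated (fake) quantization: map each group of 128 weights along the
    # input dimension onto a 2-bit grid, then dequantize back to float.
    out_features, in_features = weight.shape
    qmax = 2 ** bits - 1                              # 4 levels: 0..3
    w = weight.reshape(out_features, in_features // group_size, group_size)
    wmin = w.amin(dim=-1, keepdim=True)
    wmax = w.amax(dim=-1, keepdim=True)
    scale = (wmax - wmin).clamp(min=1e-8) / qmax      # per-group scale
    zero = (-wmin / scale).round()                    # per-group zero-point (asymmetric)
    q = (w / scale + zero).round().clamp(0, qmax)     # 2-bit integer codes
    w_dq = (q - zero) * scale                         # dequantized weights
    return w_dq.reshape(out_features, in_features)

w = torch.randn(4096, 4096)
print("relative MSE:", ((w - fake_quant_w2_g128(w)) ** 2).mean() / (w ** 2).mean())

With only four levels per group, the reconstruction error is large, which is why the 2-bit runs need a bigger calibration set than the higher-bit ones.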
Section G of the Appendix provides the detailed quantization settings. An example of the command we used to train w2a16g128 is:
CUDA_VISIBLE_DEVICES=2 python main.py \
--model_path path/to/models/Meta-Llama-3-8B \
--model_name Llama-3-8b \
--output_dir ./log/llama-3-8b-w2g128-rotate-prefix-mseinit \
--eval_ppl \
--set_prefixed_tokens \
--pre_rotate \
--wbits 2 \
--w_group_size 128 \
--w_asym \
--mse_init \
--quant_lr 1e-4 \
--weight_lr 2e-5 \
--min_lr_factor 20 \
--epochs 2 \
--train_size 4096 \
--batch_size 2 \
--training_seqlen 2048 \
--calib_dataset redpajama \
--eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande
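For reference, below is an illustrative sketch of how a 4096-sample, 2048-token calibration set could be drawn from RedPajama. This is an assumption for illustration only, not the repository's actual --calib_dataset loader; the Hugging Face dataset id "togethercomputer/RedPajama-Data-1T-Sample" and its "text" column are assumed.

import random
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

def build_redpajama_calib(model_path, n_samples=4096, seqlen=2048, seed=0):
    # Sample random 2048-token windows from RedPajama documents.
    random.seed(seed)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    data = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
    samples = []
    while len(samples) < n_samples:
        text = data[random.randrange(len(data))]["text"]
        ids = tokenizer(text, return_tensors="pt").input_ids
        if ids.shape[1] <= seqlen:                    # skip documents shorter than seqlen
            continue
        start = random.randrange(ids.shape[1] - seqlen)
        samples.append(ids[:, start:start + seqlen])
    return torch.cat(samples, dim=0)                  # shape: (n_samples, seqlen)

calib_ids = build_redpajama_calib("path/to/models/Meta-Llama-3-8B")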
To reproduce the Llama-3 8B weight-only quantization results, should we use 4096 RedPajama samples with a 2048 context length for 2 epochs, or 512 Pile samples with a 1k context length for 10 epochs? Which approach would you recommend, and why?