microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI

About LongNet's ppl #95

Open robotzheng opened 6 months ago

robotzheng commented 6 months ago

I ran train.py three times on wikitext-103, varying only --segment-length and --dilated-ratio. The common command is:

torchrun --nproc_per_node=8 --nnodes=1 train.py ../../../fairseq/data-bin/wikitext-103/ --num-workers 0 --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 4096 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --flash-attention

- --segment-length [1024,2048,4096] --dilated-ratio [1,2,4]: the best ppl on the val dataset is 29.11
- --segment-length [2048,4096] --dilated-ratio [1,2]: the best ppl on the val dataset is 28.08
- --segment-length [4096] --dilated-ratio [1]: the best ppl on the val dataset is 27.92
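For context, here is a rough sketch of how those segment-length/dilated-ratio pairs shape the attention pattern. It is only an approximation of LongNet-style dilated attention (each branch with segment length w and dilation ratio r lets a query attend to roughly w // r keys, and the branches are then mixed), not the exact masking implemented in torchscale:

```python
# Approximate per-branch key coverage for the three configs above
# (ppl values are the ones measured in this issue).
configs = {
    "w=[1024,2048,4096], r=[1,2,4] (ppl 29.11)": ([1024, 2048, 4096], [1, 2, 4]),
    "w=[2048,4096],      r=[1,2]   (ppl 28.08)": ([2048, 4096], [1, 2]),
    "w=[4096],           r=[1]     (ppl 27.92)": ([4096], [1]),
}

for name, (segs, ratios) in configs.items():
    # each (w, r) branch covers about w // r key positions per query
    per_branch = [w // r for w, r in zip(segs, ratios)]
    print(f"{name}: keys per query per branch ~ {per_branch}")
```

Under this approximation, the [4096] / [1] run is dense over the whole 4096-token sample, while the mixed configs see fewer keys per branch, which matches the ppl ordering I observed.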

More dilation gives worse ppl, which does not match the paper. The speeds of the runs above are also almost identical, which likewise does not match the paper.

shumingma commented 6 months ago

We didn't claim that more dilation is better (consider the extreme case where the segment length starts from 1). We suggest keeping the segment length at no less than 2048 for language tasks, which gives a good balance of efficiency and accuracy. In the paper, we use segment lengths of {2048, 4096, 8192, 16384, 32768}.

As shown in Figure 5 of our paper, the speedup is significant only when the sequence length is greater than 8K. This is because flash-attn is extremely well optimized for dense attention, so attention accounts for only a small fraction of the cost when the sequence length is relatively short. In that case, any further optimization of the attention becomes invisible.

[attached image: Figure 5 from the paper]
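To make that argument concrete, here is a back-of-the-envelope FLOP sketch in the spirit of Amdahl's law. The hidden size and the 4x FFN ratio are assumptions for illustration, not the exact lm_base config, and flash-attn makes the quadratic part even cheaper in wall-clock time than raw FLOP counts suggest:

```python
# Per transformer layer: QKV/output projections + FFN scale linearly in the
# sequence length L (~24*L*d^2 with a 4x FFN), while the attention-score and
# attention*value matmuls scale quadratically (~4*L^2*d).  Even if sparsifying
# the attention made the quadratic term free, the whole-layer speedup is
# bounded by 1 / (1 - quadratic_share).
d = 768  # assumed model dimension, for illustration only

for L in [2048, 4096, 8192, 16384, 32768]:
    linear = 24 * L * d * d        # projections + FFN
    quadratic = 4 * L * L * d      # QK^T and attention * V
    frac = quadratic / (linear + quadratic)
    max_speedup = 1.0 / (1.0 - frac)  # best case: attention cost drops to zero
    print(f"L={L:6d}: quadratic share ~ {frac:5.1%}, layer speedup bound ~ {max_speedup:4.1f}x")
```

Under these assumptions the quadratic share is only about a third of the layer at 2K tokens but dominates by 32K, which is why the measured speedup only becomes visible at longer sequence lengths.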

robotzheng commented 6 months ago

Thanks for your answer, I will run more experiments.