jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0

Lightseq #4

Closed · jzhang38 closed this issue 3 months ago

jzhang38 commented 3 months ago

There are still some errors.

jzhang38 commented 3 months ago

```bash
accelerate launch \
    --config_file accelerate_configs/single_node.yaml \
    train.py \
    --batch-size 1 \
    --gradient-accumulate-every 1 \
    --output-dir ./output/dist_attn_debug \
    --wandb EasyContext \
    --max-train-steps 100 \
    --learning-rate 2e-5 \
    --dataset yaofu_data \
    --model meta-llama/Llama-2-7b-hf \
    --seq-length 64000 \
    --rope-theta 1000000 \
    --dist_flash_attention
```

```bash
accelerate launch \
    --config_file accelerate_configs/single_node.yaml \
    train.py \
    --batch-size 1 \
    --gradient-accumulate-every 1 \
    --output-dir ./output/ring_attn_debug \
    --wandb EasyContext \
    --max-train-steps 100 \
    --learning-rate 2e-5 \
    --dataset yaofu_data \
    --model meta-llama/Llama-2-7b-hf \
    --seq-length 64000 \
    --rope-theta 1000000 \
    --ring_attention
```

```bash
accelerate launch --num_processes 8 --config_file accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 \
    eval_needle.py \
    --model ./output/dist_attn_debug \
    --max_context_length 5000 \
    --min_context_length 500 \
    --context_interval 500 \
    --depth_interval 0.1 \
    --num_samples 1 \
    --rnd_number_digits 7 \
    --haystack_dir PaulGrahamEssays
```

```bash
accelerate launch --num_processes 8 --config_file accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 \
    eval_needle.py \
    --model output/ring_attn_debug \
    --max_context_length 5000 \
    --min_context_length 500 \
    --context_interval 500 \
    --depth_interval 0.1 \
    --num_samples 1 \
    --rnd_number_digits 7 \
    --haystack_dir PaulGrahamEssays
```
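For reference, the flags above imply a sweep over context lengths and needle depths. Below is a minimal sketch of how that grid would be enumerated, assuming lengths step from `min_context_length` to `max_context_length` by `context_interval` and depths step from 0 to 1 by `depth_interval`; the actual loop inside `eval_needle.py` may differ.

```python
# Hypothetical sketch of the evaluation grid implied by the flags above;
# eval_needle.py's real enumeration may differ in details.
min_context_length = 500
max_context_length = 5000
context_interval = 500
depth_interval = 0.1
num_samples = 1

context_lengths = range(min_context_length, max_context_length + 1, context_interval)
depths = [round(i * depth_interval, 2) for i in range(int(1 / depth_interval) + 1)]

for length in context_lengths:      # 500, 1000, ..., 5000 tokens
    for depth in depths:            # 0.0, 0.1, ..., 1.0 (relative position of the needle)
        for _ in range(num_samples):
            # Insert a 7-digit random "needle" at the given depth inside a
            # `length`-token haystack built from PaulGrahamEssays, then ask the
            # model to retrieve it.
            pass
```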

ring flash attention:

```
1/100 [00:20<34:10, 20.71s/it, loss=5.55, ppl=257]
2/100 [00:34<27:01, 16.55s/it, loss=4.39, ppl=80.3]
3/100 [00:46<23:44, 14.68s/it, loss=tensor(4.2970, device='cuda:0'), ppl=73.5]
4/100 [00:59<22:15, 13.92s/it, loss=tensor(3.9672, device='cuda:0'), ppl=52.8]
5/100 [01:12<21:24, 13.52s/it, loss=tensor(3.3602, device='cuda:0'), ppl=28.8]
6/100 [01:25<20:46, 13.27s/it, loss=tensor(2.7400, device='cuda:0'), ppl=15.]
7/100 [01:37<20:18, 13.10s/it, loss=tensor(2.0927, device='cuda:0'), ppl=8.11]
8/100 [01:50<19:53, 12.97s/it, loss=tensor(2.4682, device='cuda:0'), ppl=11.8]
```

dist flash attention:

```
1/100 [00:22<36:27, 22.09s/it, loss=5.55, ppl=257]
2/100 [00:39<32:02, 19.62s/it, loss=4.17, ppl=64.9]
3/100 [00:51<26:39, 16.49s/it, loss=3.9, ppl=49.3]
4/100 [01:09<26:07, 16.32s/it, loss=3.84, ppl=46.4]
5/100 [01:23<24:43, 15.62s/it, loss=3.45, ppl=31.5]
6/100 [01:38<23:43, 15.14s/it, loss=2.58, ppl=13.2]
7/100 [01:52<23:02, 14.86s/it, loss=2.85, ppl=17.3]
8/100 [02:06<22:36, 14.75s/it, loss=2.11, ppl=8.27]
9/100 [02:21<22:07, 14.59s/it, loss=2.3, ppl=10]
10/100 [02:35<21:47, 14.53s/it, loss=2.04, ppl=7.68]
```
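The logged ppl is just the exponential of the cross-entropy loss, which is why the two columns move together. Spot-checking a few of the values above:

```python
import math

# ppl = exp(loss); checking against some of the logged values above
for loss in (5.55, 4.39, 3.3602, 2.0927):
    print(f"loss={loss:.4f} -> ppl={math.exp(loss):.1f}")
# loss=5.5500 -> ppl=257.2
# loss=4.3900 -> ppl=80.6
# loss=3.3602 -> ppl=28.8
# loss=2.0927 -> ppl=8.1
```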

jzhang38 commented 3 months ago

[Image: ring flash attention's result]
[Image: dist attention's result]

It seems that dist attention is better than Zhu Zilin's ring flash attention implementation, probably because of this note in its README: "There are some arithmetic errors with the current implementation. The reason for them is probably that flash attention will return bf16 value for each block, so we cannot accumulate the values with the original fp32 ones."

https://github.com/zhuzilin/ring-flash-attention/blob/55ff66fd35f329dfcc24ce7a448bfdd532865966/README.md?plain=1#L38
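To illustrate the kind of error described there, here is a toy sketch (not ring-flash-attention's actual code) comparing a bf16 running accumulator with an fp32 one when each block's partial result arrives in bf16:

```python
import torch

# Toy illustration only: flash attention returns each block's partial output
# in bf16, so a running accumulator kept in bf16 compounds rounding error
# across blocks, while an fp32 accumulator stays close to the reference.
torch.manual_seed(0)
blocks = [torch.randn(4096, dtype=torch.float32) * 1e-2 for _ in range(64)]

ref = torch.stack(blocks).sum(0)                      # fp32 reference sum

acc_bf16 = torch.zeros(4096, dtype=torch.bfloat16)
acc_fp32 = torch.zeros(4096, dtype=torch.float32)
for b in blocks:
    b_bf16 = b.to(torch.bfloat16)                     # what each block returns
    acc_bf16 += b_bf16                                # accumulate in bf16
    acc_fp32 += b_bf16.to(torch.float32)              # upcast before accumulating

print("bf16 accumulator max error:", (acc_bf16.float() - ref).abs().max().item())
print("fp32 accumulator max error:", (acc_fp32 - ref).abs().max().item())
```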

jzhang38 commented 3 months ago
[Screenshot: 2024-04-06 at 8:28:46 PM]

Interesting

jzhang38 commented 3 months ago

Turns out I had introduced a bug in dist flash attn. It is now fixed; see https://github.com/jzhang38/EasyContext/pull/5

kir152 commented 2 months ago

Have you tested LightSeq on more than one node? I tried running it on two nodes a few weeks back, and it was not working.