Closed jzhang38 closed 3 months ago
```
accelerate launch \
    --config_file accelerate_configs/single_node.yaml \
    train.py \
    --batch-size 1 \
    --gradient-accumulate-every 1 \
    --output-dir ./output/dist_attn_debug \
    --wandb EasyContext \
    --max-train-steps 100 \
    --learning-rate 2e-5 \
    --dataset yaofu_data \
    --model meta-llama/Llama-2-7b-hf \
    --seq-length 64000 \
    --rope-theta 1000000 \
    --dist_flash_attention
```
```
accelerate launch \
    --config_file accelerate_configs/single_node.yaml \
    train.py \
    --batch-size 1 \
    --gradient-accumulate-every 1 \
    --output-dir ./output/ring_attn_debug \
    --wandb EasyContext \
    --max-train-steps 100 \
    --learning-rate 2e-5 \
    --dataset yaofu_data \
    --model meta-llama/Llama-2-7b-hf \
    --seq-length 64000 \
    --rope-theta 1000000 \
    --ring_attention
```
```
accelerate launch --num_processes 8 \
    --config_file accelerate_configs/deepspeed_inference.yaml \
    --main_process_port 6000 \
    eval_needle.py \
    --model ./output/dist_attn_debug \
    --max_context_length 5000 \
    --min_context_length 500 \
    --context_interval 500 \
    --depth_interval 0.1 \
    --num_samples 1 \
    --rnd_number_digits 7 \
    --haystack_dir PaulGrahamEssays
```
```
accelerate launch --num_processes 8 \
    --config_file accelerate_configs/deepspeed_inference.yaml \
    --main_process_port 6000 \
    eval_needle.py \
    --model output/ring_attn_debug \
    --max_context_length 5000 \
    --min_context_length 500 \
    --context_interval 500 \
    --depth_interval 0.1 \
    --num_samples 1 \
    --rnd_number_digits 7 \
    --haystack_dir PaulGrahamEssays
```
ring flash attention:

```
1/100 [00:20<34:10, 20.71s/it, loss=5.55, ppl=257]
2/100 [00:34<27:01, 16.55s/it, loss=4.39, ppl=80.3]
3/100 [00:46<23:44, 14.68s/it, loss=tensor(4.2970, device='cuda:0'), ppl=73.5]
4/100 [00:59<22:15, 13.92s/it, loss=tensor(3.9672, device='cuda:0'), ppl=52.8]
5/100 [01:12<21:24, 13.52s/it, loss=tensor(3.3602, device='cuda:0'), ppl=28.8]
6/100 [01:25<20:46, 13.27s/it, loss=tensor(2.7400, device='cuda:0'), ppl=15.]
7/100 [01:37<20:18, 13.10s/it, loss=tensor(2.0927, device='cuda:0'), ppl=8.11]
8/100 [01:50<19:53, 12.97s/it, loss=tensor(2.4682, device='cuda:0'), ppl=11.8]
```

dist flash attention:

```
1/100 [00:22<36:27, 22.09s/it, loss=5.55, ppl=257]
2/100 [00:39<32:02, 19.62s/it, loss=4.17, ppl=64.9]
3/100 [00:51<26:39, 16.49s/it, loss=3.9, ppl=49.3]
4/100 [01:09<26:07, 16.32s/it, loss=3.84, ppl=46.4]
5/100 [01:23<24:43, 15.62s/it, loss=3.45, ppl=31.5]
6/100 [01:38<23:43, 15.14s/it, loss=2.58, ppl=13.2]
7/100 [01:52<23:02, 14.86s/it, loss=2.85, ppl=17.3]
8/100 [02:06<22:36, 14.75s/it, loss=2.11, ppl=8.27]
9/100 [02:21<22:07, 14.59s/it, loss=2.3, ppl=10]
10/100 [02:35<21:47, 14.53s/it, loss=2.04, ppl=7.68]
```
ring flash attention's result
dist attention's result!
It seems that dist attention performs better than zhuzilin's ring flash attention implementation, probably because of the known issue noted in that repo: "There are some arithmetic errors with the current implementation. The reason for them is probably that flash attention will return bf16 value for each block, so we cannot accumulate the values with the original fp32 ones."
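To illustrate why accumulating per-block results in a low-precision dtype hurts: once the running sum grows large enough, a small block contribution can fall below half the spacing between representable values and is dropped entirely. This is a toy sketch using NumPy's float16 as a stand-in for bf16 (NumPy has no bfloat16 dtype; bf16 has even fewer mantissa bits than fp16, so the effect is worse there), not the actual ring attention code:

```python
import numpy as np

N = 4096
blocks = np.ones(N, dtype=np.float32)  # per-block contributions, all 1.0

# fp32 accumulator: exact for these values
acc32 = np.float32(0.0)
for x in blocks:
    acc32 += x

# low-precision accumulator: once it reaches 2048, the spacing between
# adjacent fp16 values is 2.0, so adding 1.0 rounds back down (round
# half to even) and every remaining contribution is silently lost
acc16 = np.float16(0.0)
for x in blocks:
    acc16 = np.float16(acc16 + np.float16(x))

print(float(acc32), float(acc16))  # 4096.0 vs 2048.0
```

The same mechanism applies when flash attention returns bf16 block outputs and the running softmax statistics are kept in the same dtype instead of fp32, which would explain the noisier loss curve above.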
Interesting
Turns out I've written a bug for dist flash attn. Now it is fixed. See https://github.com/jzhang38/EasyContext/pull/5
Have you tested lightseq on more than one node? I tried running it on two nodes a few weeks back, and it was not working.
There are still some errors.