Avoid the extra memory usage caused by concat(): directly allocate the flattened fp32 grads up front and perform fp32 grad accumulation for each individual parameter on its own slice within the flattened tensor.
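For context, a minimal sketch of the two accumulation strategies (toy shapes and illustrative names only, not the actual fairscale FSDP internals):

```python
import torch

# Toy stand-ins: two parameters and their per-microbatch grads.
params = [torch.randn(4, 8), torch.randn(16)]
microbatch_grads = [[torch.randn_like(p) for p in params] for _ in range(4)]

def accumulate_then_concat(grad_lists):
    # Before: accumulate per-parameter fp32 grads, then concat() them into a
    # flattened tensor. The concat() allocates a second full-size buffer while
    # all per-parameter accumulators are still alive, raising peak memory.
    fp32_grads = [torch.zeros_like(g, dtype=torch.float32) for g in grad_lists[0]]
    for grads in grad_lists:
        for acc, g in zip(fp32_grads, grads):
            acc.add_(g.float())
    return torch.cat([g.flatten() for g in fp32_grads])

def accumulate_into_flat(grad_lists):
    # After: allocate the flattened fp32 grad buffer once and accumulate each
    # parameter's grad directly into its slice; no concat(), no extra copy.
    numels = [g.numel() for g in grad_lists[0]]
    flat_fp32_grad = torch.zeros(sum(numels), dtype=torch.float32)
    for grads in grad_lists:
        offset = 0
        for g, n in zip(grads, numels):
            flat_fp32_grad.narrow(0, offset, n).add_(g.flatten().float())
            offset += n
    return flat_fp32_grad

# Same additions in the same order, so the results match exactly.
assert torch.equal(accumulate_then_concat(microbatch_grads),
                   accumulate_into_flat(microbatch_grads))
```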
Local test

Deterministic numerical test

baseline, optimize_backward_concat=False
NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=8192 --gpu_check_level=-1 --steps=5 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=100000 --model.use_fp8=True
https://www.internalfb.com/intern/paste/P1404601998/
AVG loss: 10.8708400726318359
optimize_backward_concat=True
https://www.internalfb.com/intern/paste/P1404700768/
Memory usage

baseline, optimize_backward_concat=False

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=500000 --model.use_fp8=False
https://www.internalfb.com/intern/paste/P1404611094/
torch_cuda_max_reserved: 15.1GB
optimize_backward_concat=True, before optimization
NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=True --mem_snapshot_max_entries=500000 --model.use_fp8=False
https://www.internalfb.com/intern/paste/P1404620340/
torch_cuda_max_reserved: 17.4GB
optimize_backward_concat=True, after optimization
https://www.internalfb.com/intern/paste/P1404655599/
torch_cuda_max_reserved: 15.1GB (-13.2%)
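For reference, a minimal sketch of how the torch_cuda_max_reserved numbers and memory snapshots above can be collected; the actual harness drives this through the --mem_snapshot_* flags, and _record_memory_history / _dump_snapshot are PyTorch's private allocator-snapshot APIs:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# Mirrors --mem_snapshot_max_entries=500000 from the commands above.
torch.cuda.memory._record_memory_history(max_entries=500000)

# ... run the training steps under measurement ...

print(f"torch_cuda_max_reserved: {torch.cuda.max_memory_reserved() / 2**30:.1f}GB")
# Dump an allocator snapshot, viewable at https://pytorch.org/memory_viz.
torch.cuda.memory._dump_snapshot("mem_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```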
E2E MAST

model = llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP=4, CP=1, seq_len=1024
baseline, optimize_backward_concat=False
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-c52vf7
tflops/s = ~382
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.1149070831916.json.gz&bucket=acadia
optimize_backward_concat=True, before optimization
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-pdtcx1d5
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.24449323379469.json.gz&bucket=acadia
optimize_backward_concat=True, after optimization
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-ghg1f57z
tflops/s = ~440 (+15%)
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.17125783820625.json.gz&bucket=acadia
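As a footnote on the metric: assuming tflops/s here is the usual model-FLOPs estimate (roughly 6 FLOPs per parameter per token for forward + backward), the arithmetic looks like the sketch below; the parameter count and step time are placeholders, not values from these runs.

```python
def train_tflops_per_sec(n_params: float, tokens_per_step: int, step_time_s: float) -> float:
    # Standard estimate: ~6 FLOPs per parameter per token for a training step
    # (2x for the forward matmuls, 4x for backward).
    return 6.0 * n_params * tokens_per_step / step_time_s / 1e12

# Illustrative only: batch_size=128, seq_len=1024 -> 131072 tokens per step.
print(train_tflops_per_sec(n_params=8e9, tokens_per_step=128 * 1024, step_time_s=2.0))
```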