facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

[FSDPv1] Optimize memory usage for optimize_backward_concat=True #1186

Closed chrisxcai closed 3 months ago

chrisxcai commented 3 months ago

Avoid the extra memory usage caused by concat(): instead of concatenating per-parameter fp32 grads into a new tensor, directly allocate the flattened fp32 grad tensor up front and accumulate each parameter's fp32 grads in place into its slice of the flattened tensor.
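The slice-based accumulation can be sketched as follows. This is a minimal illustration with plain Python lists standing in for CUDA tensors; the class and method names are hypothetical, not FairScale's actual internals:

```python
# Hedged sketch of slice-based fp32 grad accumulation (hypothetical names,
# not FairScale's actual API). Plain Python lists stand in for CUDA tensors
# so the example stays self-contained.

class FlatFP32Grads:
    def __init__(self, param_numels):
        # One flattened fp32 buffer for all parameters, allocated once.
        self.numels = list(param_numels)
        self.buffer = [0.0] * sum(self.numels)
        # Precompute each parameter's (start, end) slice in the buffer.
        self.offsets = []
        start = 0
        for n in self.numels:
            self.offsets.append((start, start + n))
            start += n

    def accumulate(self, param_index, grad):
        # Accumulate a parameter's grad into its slice, in place --
        # no concat(), no transient copy of all per-parameter grads.
        start, end = self.offsets[param_index]
        assert end - start == len(grad)
        for i, g in enumerate(grad):
            self.buffer[start + i] += float(g)

flat = FlatFP32Grads([3, 2])
flat.accumulate(0, [1.0, 2.0, 3.0])  # microbatch 1
flat.accumulate(0, [0.5, 0.5, 0.5])  # microbatch 2 accumulates in place
flat.accumulate(1, [4.0, 5.0])
# flat.buffer is now [1.5, 2.5, 3.5, 4.0, 5.0]
```

The point is that the flattened fp32 buffer exists for the whole backward, and each parameter's gradient lands in its precomputed slice, so the transient "individual grads + concatenated copy" peak that concat() would create never materializes.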

Local test

Deterministic numerical test

baseline, optimize_backward_concat=False

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=8192 --gpu_check_level=-1 --steps=5 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=100000 --model.use_fp8=True

https://www.internalfb.com/intern/paste/P1404601998/

AVG loss: 10.8708400726318359

optimize_backward_concat=True

https://www.internalfb.com/intern/paste/P1404700768/

AVG loss: 10.8708400726318359

Memory usage

baseline, optimize_backward_concat=False

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=500000 --model.use_fp8=False

https://www.internalfb.com/intern/paste/P1404611094/

torch_cuda_max_reserved: 15.1GB

optimize_backward_concat=True, before optimization

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=True --mem_snapshot_max_entries=500000 --model.use_fp8=False

https://www.internalfb.com/intern/paste/P1404620340/

torch_cuda_max_reserved: 17.4GB

optimize_backward_concat=True, after optimization

https://www.internalfb.com/intern/paste/P1404655599/

torch_cuda_max_reserved: 15.1GB (-13.2%)
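A quick sanity check of the quoted savings, computed from the reserved-memory figures above (relative to the pre-optimization peak):

```python
# Reserved-memory figures from the runs above (GB).
before_gb = 17.4  # optimize_backward_concat=True, before optimization
after_gb = 15.1   # optimize_backward_concat=True, after optimization

savings_pct = (before_gb - after_gb) / before_gb * 100
print(f"-{savings_pct:.1f}%")  # -13.2%
```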

E2E MAST

model=llama3_kv8_balance2_ffn12, n_layers=1, non-PP microbatching, bs=128, fp8, TP=4, CP=1, seq_len=1024

baseline, optimize_backward_concat=False

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-c52vf7

[screenshot, 2024-06-09 5:14 PM]

tflops/s = ~382

[screenshot, 2024-06-09 5:15 PM]

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.1149070831916.json.gz&bucket=acadia

optimize_backward_concat=True, before optimization

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-pdtcx1d5

[screenshot, 2024-06-09 5:16 PM]

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.24449323379469.json.gz&bucket=acadia

optimize_backward_concat=True after optimization

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-ghg1f57z

[screenshot, 2024-06-09 5:18 PM]

tflops/s = ~440 (+15%)
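A quick check of the quoted throughput gain, from the tflops/s figures reported for the baseline and optimized runs:

```python
# Throughput figures from the E2E MAST runs above (tflops/s).
baseline_tflops = 382   # optimize_backward_concat=False
optimized_tflops = 440  # optimize_backward_concat=True, after optimization

gain_pct = (optimized_tflops - baseline_tflops) / baseline_tflops * 100
print(f"+{gain_pct:.0f}%")  # +15%
```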

[screenshot, 2024-06-09 5:19 PM]

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.17125783820625.json.gz&bucket=acadia