facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180

Open chrisxcai opened 6 months ago


If optimize_backward_concat is set to True, the backward() pass is only propagated to FSDP.flat_params, which invokes FSDP._post_backward_hook() and the cat() op, when FSDP._require_backward_grad_sync is True (e.g. on the last microbatch).
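For illustration, a minimal sketch of the deferral logic (not the actual FlattenParamsWrapper code; the class and method names below are made up): per-parameter gradients are accumulated across microbatches, and the flat gradient is only built with a single cat() when grad sync is required.

```python
import torch

class DeferredConcatGrads:
    """Illustrative only: accumulate per-parameter grads during intermediate
    microbatches; build the flat gradient (one cat()) only when
    _require_backward_grad_sync is True, i.e. on the last microbatch."""

    def __init__(self, params):
        self.params = list(params)
        self._accum = [torch.zeros_like(p) for p in self.params]

    def on_param_grad(self, idx, grad, require_backward_grad_sync):
        # Would be called from a per-parameter backward hook (hypothetical).
        self._accum[idx] += grad
        if not require_backward_grad_sync:
            # Intermediate microbatch: no cat(), no FSDP._post_backward_hook().
            return None
        # Last microbatch: pay for the cat() exactly once, then hand the
        # flat gradient to FSDP's post-backward path (reduce-scatter etc.).
        flat_grad = torch.cat([g.reshape(-1) for g in self._accum])
        for g in self._accum:
            g.zero_()
        return flat_grad
```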

Trace comparison

trace before the change (SplitWithSizesBackward triggered on every microbatch per FSDP module): https://fburl.com/perfdoctor/qdt32ibh

trace after the change (SplitWithSizesBackward triggered only on the last microbatch per FSDP module): https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia
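The same comparison can be reproduced locally by counting SplitWithSizesBackward events with torch.profiler; the toy setup below (one flat tensor split into two views) is only a stand-in for an FSDP-wrapped model, not the actual benchmark.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Stand-in for FlattenParamsWrapper: one flat parameter split into per-layer views.
flat = torch.randn(2048, requires_grad=True)
num_microbatches = 4

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(num_microbatches):
        w1, w2 = torch.split(flat, [1024, 1024])  # backward node: SplitWithSizesBackward
        loss = w1.sum() + w2.sum()
        loss.backward()

# In this toy setup the count scales with the number of microbatches; with the
# change applied to a real FSDP run, it should only show up on the last one.
count = sum(1 for e in prof.events() if "SplitWithSizesBackward" in e.name)
print("SplitWithSizesBackward events:", count)
```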

Numerics verification

Local runs with deterministic mode

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, fp8 (no 1F1B) (loss bitwise on par)

baseline

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=False

https://www.internalfb.com/intern/paste/P1363180533/

test

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=True

https://www.internalfb.com/intern/paste/P1363177870/
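For context, --enable_deterministic_training is the training script's own flag; in plain PyTorch a bitwise-reproducible run typically needs something like the following (a generic sketch, not necessarily what train.py does internally):

```python
import os
import torch

# Must be set before any cuBLAS call for deterministic GEMMs.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
```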

TP=2, GPU=8, DP = 4, BF16, non-PP microbatching (loss bitwise on par)

baseline: https://www.internalfb.com/intern/paste/P1322976356/ test : https://www.internalfb.com/intern/paste/P1322871976/
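These microbatching runs exercise exactly the path this PR targets: gradient accumulation under FSDP's no_sync() keeps _require_backward_grad_sync False until the final microbatch. A sketch of that pattern, assuming the new flag is exposed as an FSDP constructor argument (and that a torch.distributed process group is already initialized):

```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumption: optimize_backward_concat is a constructor kwarg added by this PR.
model = FSDP(torch.nn.Linear(1024, 1024).cuda(), optimize_backward_concat=True)
optimizer = torch.optim.AdamW(model.parameters())

microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(4)]

for i, x in enumerate(microbatches):
    if i < len(microbatches) - 1:
        # no_sync() leaves _require_backward_grad_sync False, so with the
        # flag on, these backward() calls skip the cat() entirely.
        with model.no_sync():
            model(x).sum().backward()
    else:
        # Last microbatch: _require_backward_grad_sync is True, so this
        # backward() triggers the single cat() and the post-backward
        # gradient reduction.
        model(x).sum().backward()

optimizer.step()
optimizer.zero_grad()
```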

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, BF16 (no 1F1B) (loss bitwise on par)

baseline https://www.internalfb.com/intern/paste/P1358660231/

test https://www.internalfb.com/intern/paste/P1358659328/

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 4, DP = 2, BF16 (1F1B) (loss bitwise on par)

baseline https://www.internalfb.com/intern/paste/P1358780690

test https://www.internalfb.com/intern/paste/P1358786994/

E2E MAST tests:

model = small, TP = 2, PP = 2, DP = 2 (loss on par)

baseline: https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd

test: https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966

Loss curves (screenshot): baseline vs. test on par.

Perf evaluation

model = llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP = 4, CP = 8

baseline: e2e TFLOP/s: 339.53, comp TFLOP/s: 625.64

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia

test: e2e TFLOP/s: 387.98 (~15% improvement), comp TFLOP/s: 817.5 (~30% improvement)

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia
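The quoted percentages follow directly from the raw numbers:

```python
baseline_e2e, test_e2e = 339.53, 387.98
baseline_comp, test_comp = 625.64, 817.5

print(f"e2e speedup:  {test_e2e / baseline_e2e - 1:.1%}")   # ~14.3%
print(f"comp speedup: {test_comp / baseline_comp - 1:.1%}") # ~30.7%
```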