pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[7/x] make profiling script support Float8Linear dynamic scaling #298

Closed vkuzo closed 2 weeks ago

vkuzo commented 2 weeks ago

Stack from ghstack (oldest at bottom):

Summary:

Run the profiling script with the settings below and verify:

  1. Performance of Float8Linear with dynamic scaling enabled for x, w, and dL_dY is very close to that of Float8DynamicLinear.
  2. Starting from all-delayed scaling and turning on dynamic scaling tensor by tensor, performance decreases step by step and approaches that of (1).

Test Plan:

python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type norm_ffn_norm --linear_type delayed

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category                                              
0_gemm        14.426     7.993       0.554       1.805
1_f8_overhead  0.000     2.004         inf       0.000
2_other        3.078     2.168       0.704       1.420
All           17.504    12.164       0.695       1.439
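
The two ratio columns appear to be derived from the measured ones: f8_div_ref reads as 1_float8 / 0_ref, and ref_div_f8 as its reciprocal. The sketch below recomputes them from the values in the table above as a sanity check. The helper name `ratios` is hypothetical, and this is not the profiling script's own code.

```python
import math

# Values copied from the table above: category -> (0_ref, 1_float8)
rows = {
    "0_gemm": (14.426, 7.993),
    "1_f8_overhead": (0.000, 2.004),
    "2_other": (3.078, 2.168),
    "All": (17.504, 12.164),
}

def ratios(ref, f8):
    """Return (f8_div_ref, ref_div_f8); a zero denominator maps to inf, as in the table."""
    f8_div_ref = f8 / ref if ref != 0 else math.inf
    ref_div_f8 = ref / f8 if f8 != 0 else math.inf
    return round(f8_div_ref, 3), round(ref_div_f8, 3)

derived = {name: ratios(ref, f8) for name, (ref, f8) in rows.items()}

for name, (f8_div_ref, ref_div_f8) in derived.items():
    print(f"{name:<14}{f8_div_ref:>8.3f}{ref_div_f8:>8.3f}")
```

A value of f8_div_ref below 1 means the float8 run spent less time in that category than the reference run. The inf entry for 1_f8_overhead only reflects that the reference run has no float8 overhead.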

python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type norm_ffn_norm --linear_type delayed --scaling_type_x dynamic

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
 category
 0_gemm        14.478     7.991       0.552       1.812
 1_f8_overhead  0.000     2.099         inf       0.000
 2_other        3.069     2.200       0.717       1.395
 All           17.547    12.289       0.700       1.428

python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type norm_ffn_norm --linear_type delayed --scaling_type_x dynamic --scaling_type_w dynamic

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
 category
 0_gemm        14.524     7.995       0.550       1.817
 1_f8_overhead  0.000     2.099         inf       0.000
 2_other        3.045     2.210       0.726       1.378
 All           17.569    12.303       0.700       1.428

python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type norm_ffn_norm --linear_type delayed --scaling_type_x dynamic --scaling_type_w dynamic --scaling_type_dL_dY dynamic

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
 category
 0_gemm        14.460     7.991       0.553       1.810
 1_f8_overhead  0.000     2.399         inf       0.000
 2_other        3.090     2.199       0.712       1.405
 All           17.550    12.588       0.717       1.394

python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type norm_ffn_norm --linear_type dynamic

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
 category
 0_gemm        14.463     7.954       0.550       1.818
 1_f8_overhead  0.000     2.401         inf       0.000
 2_other        3.071     2.213       0.721       1.388
 All           17.534    12.568       0.717       1.395
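
To check the two claims in the summary against these numbers, the "All" row of the 1_float8 column can be compared across the five runs. This is a minimal sketch: the configuration labels are shorthand introduced here, and the values are copied from the tables above in whatever units the script reports.

```python
# "All" row, 1_float8 column, copied from the five runs above.
totals = {
    "delayed": 12.164,
    "delayed + x dynamic": 12.289,
    "delayed + x, w dynamic": 12.303,
    "delayed + x, w, dL_dY dynamic": 12.588,
    "Float8DynamicLinear": 12.568,
}

baseline = totals["Float8DynamicLinear"]

# Relative gap of each configuration to Float8DynamicLinear, in percent.
gap_pct = {
    name: round(100 * (total - baseline) / baseline, 2)
    for name, total in totals.items()
}

for name, total in totals.items():
    print(f"{name:<32}{total:>8.3f}{gap_pct[name]:>8.2f}%")
```

Under these numbers, the float8 total grows with each tensor switched to dynamic scaling, and the fully dynamic Float8Linear run lands within a fraction of a percent of Float8DynamicLinear.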

Differential Revision: D59305795

vkuzo commented 2 weeks ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 2 weeks ago

This pull request has been merged in pytorch-labs/float8_experimental@7a1bdabdf8c7073f5020c1e0048340a4627004d5.