pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX.
BSD 3-Clause "New" or "Revised" License

QOL improvements to linear benchmarking script #278

Closed · vkuzo closed this 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

  1. add more command line filters
  2. add dynamic scaling (see the conceptual sketch after this list)
  3. remove float16 since it's low-pri and this cuts down benchmark time by 50%
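
As background for item 2, here is an editorial sketch, not the repository's actual code, of how the two `linear_type` values in the table below differ: dynamic scaling derives the float8 scale from the current tensor's absolute maximum on every call, while delayed scaling reuses amax values recorded in earlier iterations. The function names and helper constant are illustrative.

```python
# Conceptual sketch only; not the code path used by float8_experimental.
import torch

E4M3_MAX = 448.0  # largest representable magnitude in float8 e4m3 (torch.float8_e4m3fn)

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # Dynamic scaling: compute the scale from the current tensor's amax
    # on every call, at the cost of a data-dependent reduction each step.
    amax = x.abs().max().clamp(min=1e-12)
    return E4M3_MAX / amax

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # Delayed scaling: reuse amax values observed in earlier iterations,
    # which avoids recomputing the reduction on the hot path.
    amax = amax_history.max().clamp(min=1e-12)
    return E4M3_MAX / amax
```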

Test Plan:

> python benchmarks/benchmark/linear_float8.py -o ~/local/tmp/test.txt

         name                shape linear_type  compiled  use_fast_accum  ref_time_sec  pt_fp8_time_sec  pt_fp8_speedup
0   attn.wqkv  (16384, 8192, 1280)     delayed     False            True      0.002154         0.004543        0.474068
1   attn.wqkv  (16384, 8192, 1280)     dynamic     False            True      0.002157         0.003988        0.540952
2     attn.w0  (16384, 1024, 8192)     delayed     False            True      0.001894         0.003842        0.492965
3     attn.w0  (16384, 1024, 8192)     dynamic     False            True      0.001894         0.003325        0.569775
4     ffn.w13  (16384, 8192, 7168)     delayed     False            True      0.010490         0.011734        0.893940
5     ffn.w13  (16384, 8192, 7168)     dynamic     False            True      0.010505         0.010818        0.970985
6      ffn.w2  (16384, 3584, 8192)     delayed     False            True      0.005433         0.007103        0.764812
7      ffn.w2  (16384, 3584, 8192)     dynamic     False            True      0.005504         0.006452        0.853085
8   attn.wqkv  (16384, 8192, 1280)     delayed     False           False      0.002167         0.004620        0.469059
9   attn.wqkv  (16384, 8192, 1280)     dynamic     False           False      0.002166         0.004044        0.535681
10    attn.w0  (16384, 1024, 8192)     delayed     False           False      0.001905         0.003920        0.485904
11    attn.w0  (16384, 1024, 8192)     dynamic     False           False      0.001907         0.003401        0.560786
12    ffn.w13  (16384, 8192, 7168)     delayed     False           False      0.010463         0.011951        0.875500
13    ffn.w13  (16384, 8192, 7168)     dynamic     False           False      0.010472         0.011011        0.951083
14     ffn.w2  (16384, 3584, 8192)     delayed     False           False      0.005453         0.007254        0.751663
15     ffn.w2  (16384, 3584, 8192)     dynamic     False           False      0.005416         0.006568        0.824689
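
A note not stated in the PR text itself, but consistent with the numbers: `pt_fp8_speedup` appears to be `ref_time_sec / pt_fp8_time_sec`, so values below 1.0 mean the eager (non-compiled) float8 path is slower than the reference linear here. A quick check against row 0:

```python
# Row 0 sanity check (attn.wqkv, delayed): speedup = ref_time / fp8_time.
ref_time_sec = 0.002154
pt_fp8_time_sec = 0.004543
print(ref_time_sec / pt_fp8_time_sec)  # ~0.474, matching pt_fp8_speedup up to rounding
```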

Differential Revision: D58396927

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

y-sq commented 3 months ago

Shall we consider using (or adding an option to use) do_bench_using_profiling for benchmarking, which only counts the GPU kernel time?

vkuzo commented 3 months ago

> Shall we consider using (or adding an option to use) do_bench_using_profiling for benchmarking, which only counts the GPU kernel time?

I'm open to it, maybe in a separate PR?
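
As an editorial aside for readers unfamiliar with the suggestion: do_bench_using_profiling (a PyTorch-internal helper) reports only GPU kernel execution time rather than end-to-end wall clock. The snippet below is a rough sketch of that measurement idea using torch.profiler directly; it is not the implementation of do_bench_using_profiling, and the helper name and iteration counts are made up.

```python
import torch
from torch.profiler import ProfilerActivity, profile

def cuda_kernel_time_sec(fn, n_warmup: int = 10, n_iter: int = 100) -> float:
    # Warm up first so one-time costs (lazy init, autotuning) are excluded.
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()
    # Record only CUDA activity so host-side launch overhead is not counted.
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(n_iter):
            fn()
        torch.cuda.synchronize()
    # key_averages() reports times in microseconds; sum the GPU self time.
    total_us = sum(evt.self_cuda_time_total for evt in prof.key_averages())
    return total_us / 1e6 / n_iter
```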

facebook-github-bot commented 3 months ago

This pull request has been merged in pytorch-labs/float8_experimental@1e9add319830be21520333c146232f9c0670b16c.