pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

Add rowwise scaling to Float8Inference module #305

Open drisspg opened 3 months ago

Summary

### Performance

Comparison results:

| Variant                  | Time (μs) | Speedup vs BF16 | MAE vs BF16 |
|--------------------------|-----------|-----------------|-------------|
| BF16                     | 2540.56   | 1.00x           | 0           |
| FP8 Dynamic              | 1512.96   | 1.68x           | 0.00543213  |
| FP8 Static               | 1363.75   | 1.86x           | 0.00546265  |
| FP8 Weight Only          | 2774.22   | 0.92x           | 0.00379944  |
| FP8 Dynamic AxisWise     | 1510.82   | 1.68x           | 0.00543213  |
| FP8 Static AxisWise      | 1438.92   | 1.77x           | 0.00546265  |
| FP8 Weight Only AxisWise | 2762.88   | 0.92x           | 0.00379944  |
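For readers unfamiliar with the terminology, here is a minimal pure-Python sketch of the difference between tensorwise scaling (one scale for the whole tensor) and axiswise/rowwise scaling (one scale per row), and of how a mean absolute error against the unquantized values can be measured. It is not the code in this PR and does not use the `Float8Inference` API. The helper names (`fake_e4m3`, `tensorwise`, `rowwise`, `mae`) are hypothetical, the float8 cast is simulated assuming an e4m3-style format with a maximum of 448, and the input is a toy matrix chosen to have rows of very different magnitude, so it does not reproduce the numbers reported above.

```python
import math

E4M3_MAX = 448.0  # assumption: e4m3-style float8 with largest finite value 448


def fake_e4m3(v: float) -> float:
    """Simulate a saturating cast to an e4m3-like grid (3 mantissa bits)."""
    if v == 0.0:
        return 0.0
    v = max(-E4M3_MAX, min(E4M3_MAX, v))  # saturate instead of overflowing
    _, e = math.frexp(v)                  # |v| lies in [2**(e-1), 2**e)
    e = max(e, -5)                        # below 2**-6 use the fixed subnormal step
    step = 2.0 ** (e - 4)                 # 8 representable steps per binade
    return round(v / step) * step         # round half to even


def amax(values):
    # Guard against an all-zero input so the scale stays finite.
    return max(max(abs(v) for v in values), 1e-12)


def quant_dequant(row, scale):
    # Scale into the float8 range, round, then scale back for comparison.
    return [fake_e4m3(v * scale) / scale for v in row]


def tensorwise(x):
    # One scale derived from the largest magnitude in the whole tensor.
    scale = E4M3_MAX / amax([v for row in x for v in row])
    return [quant_dequant(row, scale) for row in x]


def rowwise(x):
    # One scale per row, derived from that row's largest magnitude.
    return [quant_dequant(row, E4M3_MAX / amax(row)) for row in x]


def mae(a, b):
    diffs = [abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb)]
    return sum(diffs) / len(diffs)


# Toy input: the second row is about five orders of magnitude smaller.
x = [[100.0, -50.0, 25.0, 12.5],
     [0.001, -0.002, 0.003, 0.0005]]

mae_tensor = mae(x, tensorwise(x))
mae_row = mae(x, rowwise(x))
print(f"tensorwise MAE: {mae_tensor:.3e}")
print(f"rowwise MAE:    {mae_row:.3e}")
```

In this toy case the single tensorwise scale pushes the small row into the coarse subnormal range, while a per-row scale keeps it in the normal range. When all rows have similar magnitude, the two schemes would be expected to behave much more alike.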


### Numerics

Evaluated using https://github.com/pytorch/ao/pull/446.

TensorWise Dynamic Scaling:

``` Shell
+------------+--------------------------------------------+
| Task       | Metrics                                    |
+============+============================================+
| winogrande | +-----------------+----------+             |
|            | | acc,none        | 0.735596 |             |
|            | +-----------------+----------+             |
|            | | acc_stderr,none | 0.012395 |             |
|            | +-----------------+----------+             |
+------------+--------------------------------------------+
| wikitext   | +-----------------------------+----------+ |
|            | | bits_per_byte,none          | 0.538637 | |
|            | +-----------------------------+----------+ |
|            | | bits_per_byte_stderr,none   | N/A      | |
|            | +-----------------------------+----------+ |
|            | | byte_perplexity,none        | 1.452600 | |
|            | +-----------------------------+----------+ |
|            | | byte_perplexity_stderr,none | N/A      | |
|            | +-----------------------------+----------+ |
|            | | word_perplexity,none        | 7.363215 | |
|            | +-----------------------------+----------+ |
|            | | word_perplexity_stderr,none | N/A      | |
|            | +-----------------------------+----------+ |
+------------+--------------------------------------------+

AxisWise Dynamic Scaling

+------------+--------------------------------------------+
| Task       | Metrics                                    |
+============+============================================+
| winogrande | +-----------------+----------+             |
|            | | acc,none        | 0.735596 |             |
|            | +-----------------+----------+             |
|            | | acc_stderr,none | 0.012395 |             |
|            | +-----------------+----------+             |
+------------+--------------------------------------------+
| wikitext   | +-----------------------------+----------+ |
|            | | bits_per_byte,none          | 0.538637 | |
|            | +-----------------------------+----------+ |
|            | | bits_per_byte_stderr,none   | N/A      | |
|            | +-----------------------------+----------+ |
|            | | byte_perplexity,none        | 1.452600 | |
|            | +-----------------------------+----------+ |
|            | | byte_perplexity_stderr,none | N/A      | |
|            | +-----------------------------+----------+ |
|            | | word_perplexity,none        | 7.363215 | |
|            | +-----------------------------+----------+ |
|            | | word_perplexity_stderr,none | N/A      | |
|            | +-----------------------------+----------+ |
+------------+--------------------------------------------+
```

Stack from ghstack (oldest at bottom):