pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

static scaling support for training #306

vkuzo closed 1 week ago

vkuzo commented 2 weeks ago

Stack from ghstack (oldest at bottom):

Summary:

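Some activations, such as sigmoid, have a bounded output range (sigmoid outputs lie in (0, 1)). This PR adds support for static scaling in training, so the float8 scale for such a tensor can be set from its known range instead of being recomputed from the data.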
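
For illustration, a minimal sketch of the idea in plain Python, not the code in this PR. The helper names (`static_scale`, `dynamic_scale`, `scale_and_clamp`) are hypothetical, and the actual float8 cast (rounding) is left out; the only external fact assumed is that the largest finite float8 e4m3fn value is 448.

```python
import math

# Largest finite value representable in float8 e4m3fn.
E4M3_MAX = 448.0


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def dynamic_scale(values, fp8_max=E4M3_MAX, eps=1e-12):
    """Dynamic scaling: reduce over the data to find its max magnitude."""
    amax = max(abs(v) for v in values)
    return fp8_max / max(amax, eps)


def static_scale(amax_bound, fp8_max=E4M3_MAX):
    """Static scaling: the bound is known ahead of time, so no reduction is needed."""
    return fp8_max / amax_bound


def scale_and_clamp(values, scale, fp8_max=E4M3_MAX):
    """Scale into the float8 range and saturate (the float8 rounding step is omitted)."""
    return [min(max(v * scale, -fp8_max), fp8_max) for v in values]


# Sigmoid outputs are bounded by 1.0, so that bound can serve as a fixed amax.
activations = [sigmoid(v) for v in (-4.0, -1.0, 0.0, 2.0, 6.0)]
s_static = static_scale(1.0)
s_dynamic = dynamic_scale(activations)
scaled = scale_and_clamp(activations, s_static)
```

The static scale is never larger than the dynamic one here, since the observed max is below the bound; the trade-off is slightly less of the float8 range used in exchange for skipping the per-tensor max computation.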

Test Plan:

// unit tests
pytest test/test_base.py

// baseline
> python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear
...
 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category                                              
0_gemm         0.637     0.353       0.555       1.803
1_f8_overhead  0.000     0.175         inf       0.000
2_other        0.224     0.199       0.888       1.126
All            0.861     0.727       0.844       1.184

> python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear --scaling_type_x static
...
 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category                                              
0_gemm         0.635     0.360       0.566       1.766
1_f8_overhead  0.000     0.182         inf       0.000
2_other        0.224     0.154       0.688       1.454
All            0.859     0.696       0.810       1.234
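
For reading the two tables: the ratio columns appear to be `1_float8 / 0_ref` and `0_ref / 1_float8`, which the `All` rows reproduce to three decimals (other rows can differ in the last digit, presumably because the ratios were computed before rounding). A small sketch that checks this and compares the float8 totals of the two runs; the function and variable names are hypothetical.

```python
def ratios(ref, f8):
    """Return (f8_div_ref, ref_div_f8) for one table row."""
    return f8 / ref, ref / f8


# "All" rows copied from the tables above: (0_ref, 1_float8).
baseline_all = (0.861, 0.727)
static_all = (0.859, 0.696)

base_f8_div_ref, base_ref_div_f8 = ratios(*baseline_all)
stat_f8_div_ref, stat_ref_div_f8 = ratios(*static_all)

# Relative reduction of the float8 total when scaling_type_x is static.
f8_total_reduction = (baseline_all[1] - static_all[1]) / baseline_all[1]

print(f"baseline: f8_div_ref={base_f8_div_ref:.3f} ref_div_f8={base_ref_div_f8:.3f}")
print(f"static:   f8_div_ref={stat_f8_div_ref:.3f} ref_div_f8={stat_ref_div_f8:.3f}")
print(f"float8 total reduction with static scaling: {f8_total_reduction:.1%}")
```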

vkuzo commented 1 week ago

keeping this in the back pocket until it's needed, abandoning for now