pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

delayed scaling: delete Float8LinearMixin #276

Closed · vkuzo closed this 1 month ago

vkuzo commented 1 month ago

Stack from ghstack (oldest at bottom):

Summary:

The mixin was originally used to share code with the Float8 versions of RowParallelLinear and ColParallelLinear. Since we moved those to DTensor, the mixin is no longer needed. Removing it simplifies the code in preparation for upcoming delayed scaling improvements.
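For context on the DTensor point, a rough sketch of the resulting usage shape (the import path and the from_float signature are assumptions, not verified against this exact revision): a single Float8Linear class covers the linear conversion, and tensor parallelism is layered on with DTensor APIs instead of dedicated Float8 row/column parallel subclasses sharing a mixin.

```python
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear  # import path assumed

# Single-device usage: convert a plain linear to its float8 counterpart.
lin = nn.Linear(32, 64)
f8_lin = Float8Linear.from_float(lin)  # from_float signature assumed

# Tensor parallelism is now expressed with DTensor-based APIs
# (torch.distributed.tensor.parallel.parallelize_module with
# ColwiseParallel / RowwiseParallel) layered on top of the converted module,
# rather than with Float8 RowParallelLinear / ColParallelLinear subclasses
# that needed a shared mixin.
```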

In addition, this change makes the from_float conversion use the meta device to speed it up.
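A minimal sketch of the meta-device pattern this refers to (the helper name is hypothetical and nn.Linear stands in for Float8Linear; this is not the actual implementation): constructing the new module on the meta device skips parameter allocation and random initialization, and the source module's already-materialized parameters are reused.

```python
import torch
import torch.nn as nn

def from_float_sketch(mod: nn.Linear) -> nn.Linear:
    # Build the replacement module on the meta device: parameter tensors are
    # created without allocating storage or running random initialization.
    with torch.device("meta"):
        new_mod = nn.Linear(
            mod.in_features, mod.out_features, bias=mod.bias is not None
        )
    # Reuse the source module's already-initialized parameters instead of
    # materializing fresh ones and copying data into them.
    new_mod.weight = mod.weight
    new_mod.bias = mod.bias
    return new_mod
```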

Test Plan:

```
pytest test/test_base.py -s -x
pytest test/test_compile.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D58396926

vkuzo commented 1 month ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 month ago

This pull request has been merged in pytorch-labs/float8_experimental@323fb489304bcbf5c2521af6bf948ecc35d84bc5.