pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

delayed scaling: delete Float8LinearMixin #271

Closed vkuzo closed 3 months ago

vkuzo commented 4 months ago

Stack from ghstack (oldest at bottom):

Summary:

The mixin was originally used to share code with the Float8 versions of RowParallelLinear and ColParallelLinear. Since those were moved to DTensor, the mixin is no longer needed. Removing it simplifies the code in preparation for upcoming delayed scaling improvements.

In addition, the from_float conversion now uses the meta device to speed it up.
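The meta-device trick can be sketched as follows. This is a hypothetical stand-in class, not the repo's actual Float8Linear implementation: the idea is that constructing the replacement module on the `"meta"` device skips allocating and initializing weight memory that would be immediately overwritten by the source module's parameters.

```python
import torch
import torch.nn as nn

class ToyFloat8Linear(nn.Linear):
    """Illustrative stand-in for a float8 linear; assumes PyTorch >= 2.0
    for the torch.device context manager."""

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "ToyFloat8Linear":
        # Allocate parameters on the meta device: no real memory is touched,
        # so construction cost is independent of layer size.
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features,
                          bias=mod.bias is not None)
        # Swap in the real tensors from the source module.
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod

lin = nn.Linear(64, 32)
f8 = ToyFloat8Linear.from_float(lin)
assert f8.weight is lin.weight  # shares the original (non-meta) weight
```

Since the meta-device parameters are discarded and replaced by the source module's tensors, the converted module produces the same outputs as the original linear layer.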

Test Plan:

pytest test/test_base.py -s -x
pytest test/test_compile.py -s -x


Differential Revision: D58396872

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


vkuzo commented 3 months ago

recreated in https://github.com/pytorch-labs/float8_experimental/pull/276 to get around ghstack weirdness