pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[2/x] clean up casting functions: delayed scaling #346

Closed vkuzo closed 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

Removes delayed scaling from float8_tensor.py. After this PR, the invariant is that everything in float8_tensor.py requires the scale to be calculated elsewhere. This moves the codebase toward a separation of concerns: calculating the scale (via the various scaling strategies) is kept apart from creating an instance of Float8Tensor.

Note that stateful delayed scaling is the reason we need this separation.
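
To illustrate the separation described above, here is a minimal, framework-free sketch in plain Python. It is not the code in float8_tensor.py: `ScaledValues`, `DelayedScaler`, `dynamic_scale`, `amax_to_scale`, and `cast_with_scale` are hypothetical names, the initial scale of 1.0 and the history length are assumptions, and real float8 rounding is omitted (values are only scaled and clamped to the e4m3 range). The point is that the cast takes an already-computed scale, while any state (the amax history) lives in the scaling strategy.

```python
from collections import deque
from dataclasses import dataclass

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


@dataclass(frozen=True)
class ScaledValues:
    """Hypothetical stand-in for Float8Tensor: scaled data plus its scale."""
    data: tuple
    scale: float


def amax_to_scale(amax, eps=1e-12):
    # Map the largest observed magnitude onto the top of the float8 range.
    return E4M3_MAX / max(amax, eps)


def dynamic_scale(values):
    # Stateless strategy: the scale depends only on the current values.
    return amax_to_scale(max(abs(v) for v in values))


class DelayedScaler:
    """Stateful strategy: the scale comes from amax values seen earlier."""

    def __init__(self, history_len=4):  # history length is an assumption
        self.amax_history = deque(maxlen=history_len)

    def next_scale(self, values):
        # Use the history from previous calls; fall back to 1.0 when empty
        # (an assumed initial scale for this sketch).
        if self.amax_history:
            scale = amax_to_scale(max(self.amax_history))
        else:
            scale = 1.0
        # Record the current amax so that later calls can use it.
        self.amax_history.append(max(abs(v) for v in values))
        return scale


def cast_with_scale(values, scale):
    # The cast never computes a scale; it only applies the one it is given.
    clamped = tuple(max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in values)
    return ScaledValues(data=clamped, scale=scale)
```

Under these assumptions, both strategies feed the same cast: `cast_with_scale(x, dynamic_scale(x))` for dynamic scaling, or `cast_with_scale(x, scaler.next_scale(x))` with a long-lived `DelayedScaler` for delayed scaling. Because only the scaler carries state, the cast stays a pure function of its inputs.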

Test Plan:

./test/test_everything.sh

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D60291447

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

This pull request has been merged in pytorch-labs/float8_experimental@ab2b828936412ec40bf45f18335ba62fdef60230.