pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

[2/x] clean up casting functions: delayed scaling #343

Closed: vkuzo closed this PR 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

Removes delayed scaling from float8_tensor.py. After this PR, the invariant is that nothing in float8_tensor.py computes the scale: everything there requires a scale that was calculated elsewhere. This moves the codebase toward a separation of concerns, where calculating the scale (via the various scaling strategies) is kept separate from creating an instance of Float8Tensor.

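Stateful delayed scaling is the reason this separation is needed: its scale is derived from an amax history kept across iterations, so it cannot be computed from the tensor being cast alone.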
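The invariant described above can be illustrated with a minimal, framework-free sketch. It assumes the common convention scale = fp8_max / amax and uses plain Python lists instead of torch tensors, so there is no real float8 dtype: values are scaled and saturated to the e4m3 range but not rounded to the float8 grid. All names here (`ScaledTensor`, `cast_with_scale`, `DelayedScaling`, `dynamic_scale`) are hypothetical and are not the float8_experimental API. The point is only that the casting function takes the scale as an input, while the stateful strategy owns the amax history.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List

# Largest finite value of the float8 e4m3fn format
# (torch.finfo(torch.float8_e4m3fn).max == 448.0).
E4M3_MAX = 448.0
EPS = 1e-12  # guards against division by zero for all-zero inputs


@dataclass
class ScaledTensor:
    """Hypothetical stand-in for Float8Tensor: scaled data plus the scale used."""
    data: List[float]
    scale: float

    def to_original_precision(self) -> List[float]:
        # Undo the scaling; saturated values do not round-trip exactly.
        return [v / self.scale for v in self.data]


def cast_with_scale(x: List[float], scale: float) -> ScaledTensor:
    """Casting step: the scale is always an input and is never computed here.

    Simulated cast: multiply by the scale and saturate to [-E4M3_MAX, E4M3_MAX];
    no rounding to the float8 grid is performed.
    """
    scaled = [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in x]
    return ScaledTensor(data=scaled, scale=scale)


def amax(x: List[float]) -> float:
    """Maximum absolute value of x."""
    return max(abs(v) for v in x)


def dynamic_scale(x: List[float]) -> float:
    """Stateless strategy: the scale comes from the tensor being cast."""
    return E4M3_MAX / max(amax(x), EPS)


class DelayedScaling:
    """Hypothetical stateful strategy that keeps a bounded amax history."""

    def __init__(self, history_len: int = 16) -> None:
        self.amax_history: Deque[float] = deque(maxlen=history_len)

    def scale(self) -> float:
        # Uses only amax values recorded in earlier steps, not the current tensor.
        # Returning 1.0 for an empty history is an assumption of this sketch.
        if not self.amax_history:
            return 1.0
        return E4M3_MAX / max(max(self.amax_history), EPS)

    def update(self, x: List[float]) -> None:
        # Record this step's amax; the oldest entry drops out once maxlen is hit.
        self.amax_history.append(amax(x))
```

Because `cast_with_scale` never looks at how the scale was produced, the same function serves both `dynamic_scale(x)` and `DelayedScaling.scale()`. A training step with delayed scaling would call `scale()`, cast, and then `update()` with the current input, so a spike in the current tensor only affects later steps (and may saturate in the current one).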

Test Plan:

./test/test_everything.sh

vkuzo commented 3 months ago

Started a new PR due to a ghstack error.