pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[1/x] clean up casting functions #345

Closed: vkuzo closed this 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

This PR starts a cleanup of the private casting functions in preparation for rowwise scaling. In this PR:

  1. create float8_scaling_utils.py to unify functions which take a high precision tensor and return a float8 tensor, taking care of scaling
  2. delete Float8Tensor.to_float8 and move callsites to ToFloat8ConstrFunc, since the two functions do the same thing

The end result is a slightly cleaner state; future PRs will do more cleanup.

Test Plan:

./test/test_everything.sh

Differential Revision: D60291448

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

This pull request has been merged in pytorch-labs/float8_experimental@214da1f88c986f97444f97c5aa09b2e45bef62ee.