pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

make all 3 gemms in Float8Linear support configurability, not user facing #315

Closed vkuzo closed 4 months ago

vkuzo commented 4 months ago

Stack from ghstack (oldest at bottom):

Summary:

This PR adds plumbing toward eventually making all 3 gemms in a linear fwd/bwd configurable:

  1. add LinearMMConfig to Float8Tensor to tie together the three ScaledMMConfig objects, one per gemm
  2. add GemmInputRole to Float8Tensor to specify how to pick the right config
  3. plumb all of these throughout the codebase
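The pieces described above could fit together roughly as sketched below. This is a minimal illustration based on the names in this PR description, not the repository's actual implementation: the fields of ScaledMMConfig, the enum members of GemmInputRole, and the selection helper are all assumptions. The key idea is that each Float8Tensor carries one LinearMMConfig tying together three ScaledMMConfig objects (one per gemm in a linear fwd/bwd), and a GemmInputRole telling the dispatch which of the three to use.

```python
from collections import namedtuple
from enum import Enum

# Per-gemm scaled-mm settings; the fields here are illustrative assumptions.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig", ["emulate", "use_fast_accum"], defaults=[False, False]
)

# Ties together the three per-gemm configs (assumed field names):
#   output      - input @ weight        (forward)
#   grad_input  - grad_output @ weight  (backward)
#   grad_weight - input @ grad_output   (backward)
LinearMMConfig = namedtuple(
    "LinearMMConfig",
    ["output", "grad_input", "grad_weight"],
    defaults=[ScaledMMConfig(), ScaledMMConfig(), ScaledMMConfig()],
)

class GemmInputRole(Enum):
    """Which role a Float8Tensor plays in a gemm (assumed members)."""
    INPUT = "input"
    WEIGHT = "weight"
    GRAD_OUTPUT = "grad_output"

def choose_scaled_mm_config(a_role, a_mm_config, b_role, b_mm_config):
    """Hypothetical helper: pick the right ScaledMMConfig for a gemm
    from the roles of its two operands."""
    roles = {a_role, b_role}
    if roles == {GemmInputRole.INPUT, GemmInputRole.WEIGHT}:
        return a_mm_config.output
    elif roles == {GemmInputRole.GRAD_OUTPUT, GemmInputRole.WEIGHT}:
        return a_mm_config.grad_input
    elif roles == {GemmInputRole.GRAD_OUTPUT, GemmInputRole.INPUT}:
        return a_mm_config.grad_weight
    raise AssertionError(f"unexpected gemm input roles: {roles}")
```

With a structure like this, configuring e.g. only the grad_weight gemm means constructing a LinearMMConfig whose grad_weight field differs, and the role-based lookup routes each of the three gemms to its own settings without any user-facing API change.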

Note that none of this is user facing, and there is no logic change. Planned follow-ups:

Test Plan:

./test/test_everything.sh


Differential Revision: D59973551

vkuzo commented 4 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 4 months ago

This pull request has been merged in pytorch-labs/float8_experimental@c58fb5d6ac768f2213e7f65123cfff779bef9d87.