Closed: vkuzo closed this pull request 4 months ago
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in pytorch-labs/float8_experimental@c58fb5d6ac768f2213e7f65123cfff779bef9d87.
Stack from ghstack (oldest at bottom):
Summary:
This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable:

1. LinearMMConfig is added to Float8Tensor to tie together the three ScaledMMConfig objects, one per gemm
2. GemmInputRole is added to Float8Tensor to specify how to pick the right config

Note that none of this is user facing, and there is no logic change (a sketch of how these pieces fit together is included below). Planned follow-ups:
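To make the relationship between these objects concrete, here is a minimal sketch of how a LinearMMConfig could hold one ScaledMMConfig per gemm and how a pair of GemmInputRole values could select the right one. The field names, enum values, and the choose_scaled_mm_config helper below are illustrative assumptions, not necessarily the exact API introduced in this PR.

```python
# Illustrative sketch only; names and fields are assumptions for explanation.
import enum
from dataclasses import dataclass


@dataclass(frozen=True)
class ScaledMMConfig:
    # Per-gemm knobs (field names assumed for illustration).
    emulate: bool = False
    use_fast_accum: bool = False


@dataclass(frozen=True)
class LinearMMConfig:
    # One ScaledMMConfig per gemm in a linear layer's fwd/bwd:
    #   output      = input   @ weight_t
    #   grad_input  = grad_output @ weight
    #   grad_weight = input_t  @ grad_output
    output: ScaledMMConfig = ScaledMMConfig()
    grad_input: ScaledMMConfig = ScaledMMConfig()
    grad_weight: ScaledMMConfig = ScaledMMConfig()


class GemmInputRole(enum.Enum):
    # Which role a Float8Tensor plays as a gemm input.
    INPUT = "input"
    WEIGHT = "weight"
    GRAD_OUTPUT = "grad_output"


def choose_scaled_mm_config(
    a_role: GemmInputRole,
    a_config: LinearMMConfig,
    b_role: GemmInputRole,
    b_config: LinearMMConfig,
) -> ScaledMMConfig:
    # Hypothetical helper: given the roles of the two gemm operands,
    # pick the ScaledMMConfig for the gemm they participate in.
    if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
        assert a_config.output == b_config.output
        return a_config.output
    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
        assert a_config.grad_input == b_config.grad_input
        return a_config.grad_input
    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
        assert a_config.grad_weight == b_config.grad_weight
        return a_config.grad_weight
    else:
        raise AssertionError(f"unexpected role combination: {a_role}, {b_role}")
```

The intent of this shape, as described above, is that a single LinearMMConfig attached to a Float8Tensor can carry the configuration for whichever of the three gemms the tensor ends up in, with GemmInputRole disambiguating which entry applies.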
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59973551