Closed: vkuzo closed this pull request 3 months ago.
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in pytorch-labs/float8_experimental@345b3a56ff7f1de7061f60b844cff11bfdeb65f5.
Stack from ghstack (oldest at bottom):
* #351
* #350
* #349
* #348
* #347
* #346
* #345
Summary:
This is a redo of https://github.com/pytorch-labs/float8_experimental/pull/316
With upcoming support for scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in the fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`.
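To make the intent concrete, here is a minimal sketch (not the actual float8_experimental code) of the pattern: a `torch.autograd.Function` subclass whose forward and backward each choose their own scaling and gemm path, used from a `torch.compile`-wrapped function. The `ScaledMM` class and `_scale_for_gemm` helper are hypothetical placeholders; a real implementation would cast to a float8 dtype with a computed scale instead of the identity used here.

```python
# Sketch only: the fwd and bwd independently decide how to scale their
# operands before the gemm, which is the control point this PR adds.
import torch

def _scale_for_gemm(t: torch.Tensor) -> torch.Tensor:
    # Placeholder: a real implementation would compute a scale
    # (tensorwise, rowwise, ...) and cast to a float8 dtype here.
    return t

class ScaledMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        # fwd picks its own scaling granularity / gemm path
        return _scale_for_gemm(input) @ _scale_for_gemm(weight).t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # bwd is free to scale grad_output differently than the fwd inputs
        grad_input = _scale_for_gemm(grad_output) @ _scale_for_gemm(weight)
        grad_weight = _scale_for_gemm(grad_output).t() @ _scale_for_gemm(input)
        return grad_input, grad_weight

@torch.compile
def linear_fwd(x, w):
    return ScaledMM.apply(x, w)

if __name__ == "__main__":
    x = torch.randn(16, 32, requires_grad=True)
    w = torch.randn(64, 32, requires_grad=True)
    linear_fwd(x, w).sum().backward()
```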
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D60291446