pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.

bring back torch.autograd.Function for float8 matmul #344

Closed. vkuzo closed this pull request 3 months ago.

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

This is a redo of https://github.com/pytorch-labs/float8_experimental/pull/316

With upcoming support for scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in the forward and backward passes. A torch.autograd.Function override is the cleanest way to do that, and as of 2024 it works with torch.compile.
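For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of what such an override can look like. `Float8Matmul`, its tensorwise absmax scaling, and the dequantize-then-matmul fallback are illustrative assumptions for this sketch, not this PR's actual code:

```python
import torch

class Float8Matmul(torch.autograd.Function):
    """Sketch of an autograd.Function that owns both the forward and
    backward scaling decisions for a float8 matmul."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Tensorwise scales: map each tensor's absmax to the float8 max value.
        f8_max = torch.finfo(torch.float8_e4m3fn).max
        x_scale = f8_max / x.abs().max().clamp(min=1e-12)
        w_scale = f8_max / w.abs().max().clamp(min=1e-12)
        x_fp8 = (x * x_scale).to(torch.float8_e4m3fn)
        w_fp8 = (w * w_scale).to(torch.float8_e4m3fn)
        ctx.save_for_backward(x_fp8, w_fp8, x_scale, w_scale)
        # Dequantize and use a plain matmul for clarity; a real
        # implementation would dispatch to a fused scaled-mm kernel here.
        return (x_fp8.to(x.dtype) / x_scale) @ (w_fp8.to(w.dtype) / w_scale)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x_fp8, w_fp8, x_scale, w_scale = ctx.saved_tensors
        # The backward controls its own scaling; gradients could be cast
        # to float8_e5m2 here, but stay in grad_out's dtype for simplicity.
        grad_x = grad_out @ (w_fp8.to(grad_out.dtype) / w_scale).t()
        grad_w = (x_fp8.to(grad_out.dtype) / x_scale).t() @ grad_out
        return grad_x, grad_w
```

Calling `Float8Matmul.apply(x, w)` routes both passes through the override, so a later change of scaling granularity or gemm kernel only touches this one class.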

Test Plan:

./test/test_everything.sh


Differential Revision: D60291446

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


facebook-github-bot commented 3 months ago

This pull request has been merged in pytorch-labs/float8_experimental@345b3a56ff7f1de7061f60b844cff11bfdeb65f5.