pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[wip] add axiswise granularity to Float8Tensor #352

Open vkuzo opened 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

This PR adds axiswise scaling granularity to Float8Tensor and ensures that basic ops such as transpose and torch._scaled_mm work as expected with it.

A future PR will add integration with Float8Linear.

Test Plan:

TODO

Reviewers:

Subscribers:

Tasks:

Tags: