pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

[4/x] add tests for DTensor TP/SP + Float8Linear #294

Closed. vkuzo closed this 2 weeks ago.

vkuzo commented 2 weeks ago

Stack from ghstack (oldest at bottom):

Summary:

Extends the DTensor tensor-parallel (TP) and sequence-parallel (SP) tests so they also exercise Float8Linear, with all scaling types configured as dynamic.
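For readers unfamiliar with the term, "dynamic" scaling means the float8 scale is recomputed from the amax (maximum absolute value) of the tensor being cast at that moment. The snippet below is a conceptual, pure-Python sketch of that idea only; it is not float8_experimental's implementation, and the helper names `dynamic_scale` and `scale_and_clamp` are hypothetical. It assumes the e4m3 format (largest finite value 448.0) and omits rounding to the actual float8 grid.

```python
# Hypothetical sketch of *dynamic* float8 scaling: the scale is derived
# from the amax of the tensor being cast right now (not the library's code).
E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def dynamic_scale(values, fp8_max=E4M3_MAX, eps=1e-12):
    """Return the scale that maps the current amax onto fp8_max."""
    amax = max(abs(v) for v in values)
    return fp8_max / max(amax, eps)  # eps guards against an all-zero tensor


def scale_and_clamp(values, scale, fp8_max=E4M3_MAX):
    """Scale into the float8 range and saturate at +/- fp8_max.

    Real float8 casting would also round to the float8 grid; that step
    is deliberately omitted in this sketch.
    """
    return [min(max(v * scale, -fp8_max), fp8_max) for v in values]


x = [0.5, -2.0, 1.0]
s = dynamic_scale(x)              # 448.0 / 2.0 = 224.0
x_scaled = scale_and_clamp(x, s)
print(s, x_scaled)                # 224.0 [112.0, -448.0, 224.0]
```

Because the scale tracks the current tensor exactly, the largest element always lands on the edge of the float8 range; the trade-off is that the amax must be computed before every cast.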

We can add support for delayed scaling with float8 all-gather for x and dL_dY in a future PR, as needed.
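To contrast with dynamic scaling, "delayed" scaling computes the scale from amax values recorded in earlier iterations rather than from the tensor currently being cast. The sketch below illustrates only that general idea in pure Python; the `DelayedScaler` class, its rolling history length, and the initial scale of 1.0 when no history exists are all illustrative assumptions, not float8_experimental's API or defaults. As above, it assumes e4m3 (max 448.0) and skips rounding to the float8 grid.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


class DelayedScaler:
    """Hypothetical delayed scaler: the scale comes from a rolling
    history of past amax values, not from the tensor being cast."""

    def __init__(self, history_len=16, fp8_max=E4M3_MAX, eps=1e-12):
        self.history = deque(maxlen=history_len)  # oldest entries evicted
        self.fp8_max = fp8_max
        self.eps = eps

    def scale(self):
        if not self.history:
            return 1.0  # assumed starting scale before any amax is recorded
        return self.fp8_max / max(max(self.history), self.eps)

    def cast(self, values):
        s = self.scale()  # uses only previously recorded amax values
        out = [min(max(v * s, -self.fp8_max), self.fp8_max) for v in values]
        # Record this tensor's amax so it influences *future* scales.
        self.history.append(max(abs(v) for v in values))
        return s, out


scaler = DelayedScaler(history_len=2)
s1, _ = scaler.cast([1.0, -4.0])  # no history yet: scale 1.0
s2, _ = scaler.cast([2.0])        # history [4.0]: scale 112.0
s3, out3 = scaler.cast([8.0])     # history [4.0, 2.0]: scale 112.0, 896.0 saturates to 448.0
s_next = scaler.scale()           # history [2.0, 8.0]: scale 56.0
```

The third call shows the cost of the approach: a sudden spike in amax is not reflected until the next iteration, so the spiking value saturates. The benefit is that the scale is known before the tensor's own amax has been computed.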

Test Plan:

./test/test_dtensor.sh


Differential Revision: D59305797

vkuzo commented 2 weeks ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 2 weeks ago

This pull request has been merged in pytorch-labs/float8_experimental@3ec96650001126283002cc83595fdbf9c605090d.