pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[2/x]: fix numerics integration test and test delayed vs dynamic #291

Closed: vkuzo closed this pull request 3 months ago

vkuzo commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

  1. the SAM test wasn't easy to use because it had real weights and hence required real data for useful testing, which is inconvenient for an integration test. Switched to a LLaMA FFN with random weights, and tightened all the thresholds so the test actually checks that numerics are close.
  2. extended the numerics test to check all combinations of delayed vs. dynamic scaling
  3. to enable (2), extended the module swap utility to configure delayed vs. dynamic scaling at the model level; for now there is no option to customize further
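Item (1) does not say which closeness metric the tightened thresholds use. As a hedged illustration, one common way to compare a float8 result against a higher-precision reference is signal-to-quantization-noise ratio (SQNR) in decibels. The helpers below (`sqnr_db`, `assert_numerics_close`, and the 20 dB `SQNR_THRESHOLD_DB`) are hypothetical names and an arbitrary placeholder threshold, written in plain Python so the idea stands on its own; this is not the repository's implementation.

```python
import math
from typing import Sequence

# Hypothetical placeholder threshold; a "tight" test fails below this value.
SQNR_THRESHOLD_DB = 20.0


def sqnr_db(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """SQNR of `candidate` relative to `reference`, in dB (higher is closer)."""
    if len(reference) != len(candidate):
        raise ValueError("inputs must have the same length")
    # L2 norm of the reference signal and of the error (noise).
    signal = math.sqrt(sum(r * r for r in reference))
    noise = math.sqrt(sum((r - c) ** 2 for r, c in zip(reference, candidate)))
    if noise == 0.0:
        return math.inf  # identical inputs: no quantization noise at all
    if signal == 0.0:
        return -math.inf  # zero reference but nonzero error
    return 20.0 * math.log10(signal / noise)


def assert_numerics_close(reference, candidate, threshold_db=SQNR_THRESHOLD_DB):
    """Raise AssertionError if the candidate is not close enough to the reference."""
    value = sqnr_db(reference, candidate)
    assert value >= threshold_db, f"SQNR {value:.2f} dB < {threshold_db} dB"
    return value


# Usage: a 1% relative error corresponds to roughly 40 dB.
print(assert_numerics_close([3.0, 4.0], [3.0, 4.05]))
```

Raising the threshold is what makes such a check "tight": with random weights and a small FFN, a real regression in the float8 path shows up as an SQNR drop rather than being hidden by loose tolerances.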
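Items (2) and (3) can be sketched together in plain Python. Everything here is hypothetical: `ScalingType`, `Float8LinearConfig`, `ToyLinear`, `swap_linear_layers`, and the assumption that three tensors (input, weight, grad_output) each get their own scaling choice are illustrative, not the float8_experimental API. With two choices per tensor, "all combinations" is 2 ** 3 = 8 cases, and a model-level config gives every swapped layer the same choices, with no per-layer override.

```python
import enum
import itertools
from dataclasses import dataclass
from typing import List, Optional


class ScalingType(enum.Enum):
    # Delayed: scale derived from an amax history of previous iterations.
    # Dynamic: scale computed from the current tensor's amax.
    DELAYED = "delayed"
    DYNAMIC = "dynamic"


@dataclass(frozen=True)
class Float8LinearConfig:
    # Hypothetical model-level config: one scaling choice per tensor role.
    scaling_type_x: ScalingType
    scaling_type_w: ScalingType
    scaling_type_grad_output: ScalingType


@dataclass
class ToyLinear:
    # Stand-in for a linear layer; `config` stays None until it is swapped.
    name: str
    config: Optional[Float8LinearConfig] = None


def swap_linear_layers(layers: List[ToyLinear], config: Float8LinearConfig) -> List[ToyLinear]:
    # Model-level swap: every layer receives the same config.
    return [ToyLinear(layer.name, config) for layer in layers]


def all_scaling_configs() -> List[Float8LinearConfig]:
    # Cartesian product over the three tensor roles: 2 ** 3 = 8 configs.
    return [
        Float8LinearConfig(x, w, g)
        for x, w, g in itertools.product(ScalingType, repeat=3)
    ]


# Usage: hypothetical FFN layer names; a real test would run forward/backward
# for each config and compare against a high-precision reference.
ffn = [ToyLinear("w1"), ToyLinear("w2"), ToyLinear("w3")]
for cfg in all_scaling_configs():
    swapped = swap_linear_layers(ffn, cfg)
```

Enumerating the product in one helper keeps the test matrix in sync if another scaling type is added later; in pytest, the same effect is usually achieved by stacking one `@pytest.mark.parametrize` decorator per tensor role.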

Test Plan:

pytest test/test_numerics_integration.py -s -x
./test/test_everything.sh

Differential Revision: D59305796

vkuzo commented 3 months ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

This pull request has been merged in pytorch-labs/float8_experimental@1e71def289645a005418f5df6eaad0984ece1259.