pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

Updates with new scaled-mm api #284

Closed: drisspg closed this 3 months ago

drisspg commented 3 months ago

Summary

This updates the calls to _scaled_mm to the new signature from this PR: https://github.com/pytorch/pytorch/pull/128683

This is needed to unblock inductor work on scaled_mm.

❯ ./test/test_everything.sh
    .
    .
    .

test/test_fsdp2/test_fsdp2_eager.py .......                                   [100%]

================================ 7 passed in 27.66s =================================
all tests successful
drisspg commented 3 months ago

We have been pretty fast and loose with versioning. The only required runtime dependency is torch, and we don't specify a version for it, but the README explicitly notes that a nightly build is required: https://github.com/pytorch-labs/float8_experimental?tab=readme-ov-file#installation

I could (and probably should) update our dependency to require a newer version of PyTorch, but that still isn't perfect, since nightly builds don't follow semantic versioning.
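A plain version lower bound is not precise enough here. PyTorch nightlies report versions of the form `2.5.0.dev20240615+cu121`, so every nightly in one release cycle shares the same base version. A version number alone therefore cannot tell a nightly that has the new `_scaled_mm` signature from one built before it. One possible workaround is to gate on the nightly's build date instead. The sketch below is a minimal illustration under that assumption. `nightly_date`, `meets_min_nightly`, and the cutoff date are hypothetical and are not part of float8_experimental or PyTorch.

```python
import re
from datetime import date
from typing import Optional

# Hypothetical helper: nightly versions embed the build date after ".dev",
# e.g. "2.5.0.dev20240615+cu121" -> 2024-06-15.
_NIGHTLY_RE = re.compile(r"\.dev(\d{4})(\d{2})(\d{2})")


def nightly_date(version: str) -> Optional[date]:
    """Return the build date of a nightly version string, or None for release builds."""
    m = _NIGHTLY_RE.search(version)
    if m is None:
        return None
    year, month, day = (int(g) for g in m.groups())
    return date(year, month, day)


def meets_min_nightly(version: str, min_date: date) -> bool:
    """Return True only for nightlies built on or after min_date.

    Release builds return False; they would need a separate check
    against the first stable release that ships the new API.
    """
    built = nightly_date(version)
    return built is not None and built >= min_date
```

In practice this would be called as `meets_min_nightly(torch.__version__, cutoff)`, where `cutoff` is a placeholder for the first nightly that includes https://github.com/pytorch/pytorch/pull/128683. That date is not stated here.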

facebook-github-bot commented 3 months ago

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

@drisspg merged this pull request in pytorch-labs/float8_experimental@edae9a3e4fc3e1ae8d8cfe02738532c4067b3f4a.