pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[not for land] enumerate breakages with module hooks + compile #270

Open vkuzo opened 1 month ago

vkuzo commented 1 month ago

Summary:

This PR rewrites Float8DynamicLinear to use module hooks, which we think will be more composable with other PyTorch features in the long term. There is currently no plan to land it; the goal is to reproduce and share what breaks when we try this today.

Test Plan:

// note: all tests pass without this PR

// eager mode is fine
> pytest -s test/test_base.py | with-proxy gh gist create
https://gist.github.com/vkuzo/aded224af91092c8326becc855b125c9

// compile has some errors in aot_eager backend
> pytest -s test/test_compile.py | with-proxy gh gist create
https://gist.github.com/vkuzo/cab55b11a2c3cee0d1ff94169131b171

// dtensor + float8 has numeric issues
> ./test/test_dtensor.sh | with-proxy gh gist create
https://gist.github.com/vkuzo/d1035200db22f2e3357438824cd3594f


vkuzo commented 1 month ago

cc @bdhirsh