pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

Add more compile compatibility for Float8Tensor ops #285

Closed: ani300 closed this pull request 3 months ago

drisspg commented 3 months ago

We have been pretty targeted with the ops we support for Float8Tensor, so I am curious whether you have any concrete use cases for these ops. Could you also add some test cases? Otherwise, thanks for the contributions! I am surprised that the constructor isn't working correctly; I would also love a test case there!

ani300 commented 3 months ago

Yes, most of these ops are needed because we use Float8Tensor to handle an FP8 kv-cache. Example usage for all of them will be in https://github.com/pytorch-labs/fp468-llm today, or tomorrow at the latest. I'll add tests for both the ops and the constructor. The issue with the constructor was that it always returned fp32, regardless of what the original dtype was.

ani300 commented 3 months ago

@drisspg as I'm writing the unit tests, I'm wondering what a correct copy_ operation looks like. If we copy an FP32/FP16/BF16 tensor into an FP8 one, should we apply the Float8Tensor's scale, if it has one? And what should the opposite operation look like, say copying an FP8 tensor with scales into an FP32/FP16 one? Should we unscale through the FromFloat8Constructor?

drisspg commented 3 months ago

@ani300 Great questions

For a copy from scaled_fp8 to hp_type I think we should unscale and copy into.

For a copy from hp to an fp8 tensor, I think the semantics are a little hazier. Do you have a clear need for this operation? Otherwise I would potentially ban it for now. Some options:

I actually thought about a related problem recently when adding copy dispatch to NF4Tensor, to enable Subclass -> Subclass copy. The most reasonable semantic I could come up with is to use the high-precision dtype as the intermediary for the conversion: https://github.com/pytorch/ao/pull/45

@vkuzo any strong thoughts on the semantics here?

vkuzo commented 3 months ago

If we try to copy an FP32/FP16/BF16 tensor into an FP8 one, should we do some scaling if the Float8Tensor has it? Or what does the opposite operation look like as well?

IMO:

ani300 commented 3 months ago

Thanks @drisspg and @vkuzo for your comments and opinions! I'll implement the FP8 -> BF16 copy (which is the one I'm using anyway), add Float8Tensor -> Float8Tensor copy when everything else is equal (scale, mm_config, etc.), and ban everything else.

ani300 commented 3 months ago

For the failing unit test: I'm waiting for PyTorch CI to run again (it currently fails to run at all) so I can land https://github.com/pytorch/pytorch/pull/128758.

drisspg commented 3 months ago

@ani300 the CI failure is because we are still using last night's nightly.

facebook-github-bot commented 3 months ago

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

@drisspg merged this pull request in pytorch-labs/float8_experimental@b5a444a3ec5fcd45fe86175256d1ab862c64fcb0.