Closed gau-nernst closed 4 months ago
> Although it was originally meant for MX dtypes only (FP4 E2M1, FP6 E2M3, FP6 E3M2), expanding its functionality to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.

If it can be useful for other explorations in a more central place that sounds great to me, readability improvements also sound great.

> Ideally, we shouldn't need to calculate all the constants by ourselves, only provide number of E and M bits, and these constants should be calculated within the function

Yep, that sounds great

> Change _f32_to_f4_or_f6_unpacked() and _f4_or_f6_unpacked_to_f32() to _f32_to_fpx_unpacked(x, n_ebits, n_mbits) and _fpx_unpacked_to_f32(x, n_ebits, n_mbits) (packed format is out of scope, should be handled separately for each case)

SGTM

> Move non-mx specific stuff from custom_cast.py to an upper level e.g. prototype/fp_cast_utils.py (e.g. functions for packed fp4, custom triton kernels should stay in custom_cast.py)

SGTM
Referring to https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py
Although it was originally meant for MX dtypes only (FP4 E2M1, FP6 E2M3, FP6 E3M2), expanding its functionality to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.
Case in point, FP6-LLM upstream added support for FP5 E2M2 (https://github.com/usyd-fsalab/fp6_llm). This is what I need to write to support FP32->FP5 E2M2.
Ideally, we shouldn't need to calculate all the constants by ourselves, only provide number of E and M bits, and these constants should be calculated within the function (or cache them somewhere, though I think re-calculating these constants shouldn't take much time).
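To illustrate the idea of deriving constants from the bit counts alone, here is a minimal pure-Python sketch. The names `FpxConstants` and `fpx_constants` are hypothetical and not part of torchao. It assumes the format has no inf/NaN encodings (as with the MX FP4/FP6 dtypes), so the all-ones exponent field encodes ordinary values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FpxConstants:
    """Constants for a sign + n_ebits + n_mbits format (hypothetical helper)."""
    exp_bias: int
    max_normal: float
    min_normal: float
    min_subnormal: float


def fpx_constants(n_ebits: int, n_mbits: int) -> FpxConstants:
    """Derive format constants from the number of exponent and mantissa bits.

    Assumption: no inf/NaN encodings, so the largest exponent field is usable.
    """
    if n_ebits < 1 or n_mbits < 0:
        raise ValueError("need n_ebits >= 1 and n_mbits >= 0")
    exp_bias = 2 ** (n_ebits - 1) - 1
    max_exp_field = 2 ** n_ebits - 1
    # Largest value: all-ones exponent field, all-ones mantissa.
    max_normal = 2.0 ** (max_exp_field - exp_bias) * (2.0 - 2.0 ** -n_mbits)
    # Smallest normal: exponent field 1, mantissa 0.
    min_normal = 2.0 ** (1 - exp_bias)
    # Smallest subnormal: exponent field 0, mantissa 1.
    min_subnormal = min_normal * 2.0 ** -n_mbits
    return FpxConstants(exp_bias, max_normal, min_normal, min_subnormal)
```

For example, `fpx_constants(2, 2)` describes FP5 E2M2 and `fpx_constants(3, 2)` describes FP6 E3M2. These are a handful of integer and float operations, which is consistent with the guess that re-calculating them on each call should be cheap.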
The other direction (FPx->FP32) is a bit trickier when handling denormal FPx, but should still be possible to make it more generic.
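As a rough sketch of how the FPx->FP32 direction can be made generic, including the denormal case, the scalar pure-Python function below decodes one unpacked FPx value. The name `fpx_unpacked_to_float` is hypothetical; this is not the torchao implementation, which operates on tensors. It assumes a layout of sign bit, then `n_ebits` exponent bits, then `n_mbits` mantissa bits, with no inf/NaN encodings.

```python
def fpx_unpacked_to_float(bits: int, n_ebits: int, n_mbits: int) -> float:
    """Decode one unpacked FPx value (held in the low bits of an int).

    Assumptions: layout is [sign | exponent | mantissa], no inf/NaN encodings.
    """
    exp_bias = 2 ** (n_ebits - 1) - 1
    sign = -1.0 if (bits >> (n_ebits + n_mbits)) & 1 else 1.0
    exp_field = (bits >> n_mbits) & ((1 << n_ebits) - 1)
    mantissa = bits & ((1 << n_mbits) - 1)

    if exp_field == 0:
        # Denormal (or zero): no implicit leading 1, exponent fixed at 1 - bias.
        return sign * mantissa * 2.0 ** (1 - exp_bias - n_mbits)

    # Normal: implicit leading 1 plus the fractional mantissa.
    return sign * (1.0 + mantissa * 2.0 ** -n_mbits) * 2.0 ** (exp_field - exp_bias)
```

The only format-specific branch is `exp_field == 0`, which suggests the denormal handling can be driven by `n_ebits` and `n_mbits` rather than per-dtype constants.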
Proposed changes
- Change _f32_to_f4_or_f6_unpacked() and _f4_or_f6_unpacked_to_f32() to _f32_to_fpx_unpacked(x, n_ebits, n_mbits) and _fpx_unpacked_to_f32(x, n_ebits, n_mbits) (packed format is out of scope, should be handled separately for each case)
- Move non-mx specific stuff from custom_cast.py to an upper level e.g. prototype/fp_cast_utils.py (e.g. functions for packed fp4, custom triton kernels should stay in custom_cast.py)

Tagging @vkuzo and @msaroufim for discussion and opinion.