pytorch / ao

PyTorch native quantization and sparsity for training and inference

Make custom FPx dtype conversion easier to use #354

Closed · gau-nernst closed this 4 months ago

gau-nernst commented 5 months ago

Referring to https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py

Although it was originally meant for MX dtypes only (FP4 E2M1, FP6 E2M3, FP6 E3M2), expanding its functionality to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.

Case in point: upstream FP6-LLM added support for FP5 E2M2 (https://github.com/usyd-fsalab/fp6_llm). This is what I currently need to write to support FP32 -> FP5 E2M2:

import struct

from torch import Tensor

# existing generic cast helper from the custom_cast.py file referenced above
from torchao.prototype.mx_formats.custom_cast import _f32_to_f4_or_f6_unpacked

# float32 layout constants
MBITS_F32 = 23  # number of mantissa bits in IEEE 754 float32
F32_EXP_BIAS = 127  # float32 exponent bias

# define constants for F32 <-> F5_E2M2
F5_E2M2_MAX = 7.0  # (2 ** (0b11 - 0b01)) * (1 + 0.5 + 0.25)
F5_E2M2_MIN_NORMAL = 1.0  # (2 ** (0b01 - 0b01))
EBITS_F5_E2M2 = 2
MBITS_F5_E2M2 = 2
F5_E2M2_EXP_BIAS = 0b01
F5_E2M2_MAX_INT = (1 << 4) - 1  # max encoding of the 4 magnitude (exp + mantissa) bits
SIGN_MASK_F5_E2M2 = 1 << 4  # sign bit sits above the 4 magnitude bits

# rounding bias for round-to-nearest-even on the truncated f32 mantissa bits
# (depends on the target mantissa width, not the exponent width)
MAGIC_ADDER_F5_E2M2 = (1 << (MBITS_F32 - MBITS_F5_E2M2 - 1)) - 1

DENORM_F32TOF5_E2M2_EXP = (
    # exp bias conversion between formats
    (F32_EXP_BIAS - F5_E2M2_EXP_BIAS)
    # mantissa length difference between formats
    + (MBITS_F32 - MBITS_F5_E2M2)
    # add one to encoded exponent for denormalized numbers
    + 1
)
DENORM_F32TOF5_E2M2_MASK_INT = DENORM_F32TOF5_E2M2_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF5_E2M2_MASK_FLOAT = struct.unpack("!f", struct.pack("!I", DENORM_F32TOF5_E2M2_MASK_INT))[0]

def f32_to_f5_e2m2_unpacked(x: Tensor):
    return _f32_to_f4_or_f6_unpacked(
        x,
        F5_E2M2_MAX,
        F5_E2M2_MIN_NORMAL,
        DENORM_F32TOF5_E2M2_MASK_FLOAT,
        DENORM_F32TOF5_E2M2_MASK_INT,
        EBITS_F5_E2M2,
        MBITS_F5_E2M2,
        F5_E2M2_EXP_BIAS,
        MAGIC_ADDER_F5_E2M2,
        F5_E2M2_MAX_INT,
        SIGN_MASK_F5_E2M2,
    )
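For reference, exercising the wrapper on a small tensor would look like this (assuming, as with the existing f4/f6 helpers, that the unpacked output is an integer tensor holding one 5-bit encoding per element):

import torch

x = torch.tensor([0.0, 0.3, 1.0, -2.5, 7.0])
f5 = f32_to_f5_e2m2_unpacked(x)
print(f5)  # one 5-bit E2M2 bit pattern per input element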

Ideally, we shouldn't need to calculate all of these constants ourselves: we would only provide the number of E and M bits, and the constants would be calculated inside the function (or cached somewhere, though re-calculating them shouldn't take much time). A rough sketch of the idea is below.
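As an illustration only (the helper name _fpx_constants and its return layout are made up here, not an existing torchao API), every constant in the snippet above can be derived from just the E and M bit counts:

MBITS_F32 = 23  # IEEE 754 float32 mantissa bits
F32_EXP_BIAS = 127  # IEEE 754 float32 exponent bias

def _fpx_constants(ebits: int, mbits: int):
    # hypothetical helper: derive the FPx conversion constants from the bit counts,
    # assuming all exponent values are usable (no inf/nan encodings), as in the MX dtypes
    exp_bias = (1 << (ebits - 1)) - 1
    max_normal = 2 ** ((1 << ebits) - 1 - exp_bias) * (2 - 2 ** (-mbits))
    min_normal = 2 ** (1 - exp_bias)
    max_int = (1 << (ebits + mbits)) - 1  # all magnitude (exp + mantissa) bits set
    sign_mask = 1 << (ebits + mbits)  # sign bit above the magnitude bits
    magic_adder = (1 << (MBITS_F32 - mbits - 1)) - 1  # round-to-nearest-even bias
    denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1
    denorm_mask_int = denorm_exp << MBITS_F32
    return max_normal, min_normal, max_int, sign_mask, magic_adder, denorm_mask_int

For E2M2 this reproduces the values written out by hand above (max 7.0, min normal 1.0, sign mask 1 << 4, etc.).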

The other direction (FPx -> FP32) is a bit trickier when handling denormal FPx values, but it should still be possible to make it generic; see the scalar sketch below.
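For intuition, a scalar (non-tensorized) decode could look like the sketch below; the denormal branch is the part that needs care, since a zero biased exponent means no implicit leading 1 and a fixed exponent of 1 - bias. Names here are illustrative, not an existing API.

def _fpx_bits_to_f32_scalar(bits: int, ebits: int, mbits: int) -> float:
    # decode one unpacked FPx bit pattern (sign | exp | mantissa) to a Python float
    exp_bias = (1 << (ebits - 1)) - 1
    sign = -1.0 if (bits >> (ebits + mbits)) & 1 else 1.0
    biased_exp = (bits >> mbits) & ((1 << ebits) - 1)
    mantissa = bits & ((1 << mbits) - 1)
    if biased_exp == 0:
        # denormal: value = mantissa * 2 ** (1 - bias - mbits)
        return sign * mantissa * 2.0 ** (1 - exp_bias - mbits)
    # normal: implicit leading 1
    return sign * (1 + mantissa * 2.0 ** -mbits) * 2.0 ** (biased_exp - exp_bias)

A tensorized version of the same logic, using bit masks on integer tensors, is what a generic _fpx_unpacked_to_f32(x, n_ebits, n_mbits) would do.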

Proposed changes

- Change _f32_to_f4_or_f6_unpacked() and _f4_or_f6_unpacked_to_f32() to _f32_to_fpx_unpacked(x, n_ebits, n_mbits) and _fpx_unpacked_to_f32(x, n_ebits, n_mbits) (packed format is out of scope, should be handled separately for each case)
- Move non-mx specific stuff from custom_cast.py to an upper level e.g. prototype/fp_cast_utils.py (e.g. functions for packed fp4, custom triton kernels should stay in custom_cast.py)

Tagging @vkuzo and @msaroufim for discussion and opinions.

vkuzo commented 5 months ago

> Although it was originally meant for MX dtypes only (FP4 E2M1, FP6 E2M3, FP6 E3M2), expanding its functionality to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.

If it can be useful for other explorations in a more central place, that sounds great to me; readability improvements also sound great.

> Ideally, we shouldn't need to calculate all of these constants ourselves: we would only provide the number of E and M bits, and the constants would be calculated inside the function

Yep, that sounds great

> Change _f32_to_f4_or_f6_unpacked() and _f4_or_f6_unpacked_to_f32() to _f32_to_fpx_unpacked(x, n_ebits, n_mbits) and _fpx_unpacked_to_f32(x, n_ebits, n_mbits) (packed format is out of scope, should be handled separately for each case)

SGTM
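With this change, the FP5 E2M2 helpers from the issue description would reduce to thin wrappers around the generic functions. A sketch under the proposed (not yet existing) signatures:

from torch import Tensor

# proposed generic helpers (do not exist yet; signatures taken from the proposal above)
# from torchao.prototype.fp_cast_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

def f32_to_f5_e2m2_unpacked(x: Tensor) -> Tensor:
    return _f32_to_fpx_unpacked(x, 2, 2)  # n_ebits=2, n_mbits=2

def f5_e2m2_unpacked_to_f32(x: Tensor) -> Tensor:
    return _fpx_unpacked_to_f32(x, 2, 2)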

> Move non-mx specific stuff from custom_cast.py to an upper level e.g. prototype/fp_cast_utils.py (e.g. functions for packed fp4, custom triton kernels should stay in custom_cast.py)

SGTM
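To make the split concrete, call sites could then import the generic helpers from the new module while the MX-specific pieces keep living in custom_cast.py; the path below is the one suggested above and stays hypothetical until the refactor lands:

# generic FPx <-> FP32 helpers (suggested new location, hypothetical until the move)
from torchao.prototype.fp_cast_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

# packed fp4 helpers and custom triton kernels stay MX-specific
from torchao.prototype.mx_formats import custom_cast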