ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

[BF16] GPU Implementation #3519

Open · richagadgil opened this issue 4 days ago

richagadgil commented 4 days ago

Idea:

Cast FP32/FP16 to BF16.

The cast will differ based on the source type, and may involve a very slight loss of precision in both cases.

Workflow:

Follow a similar workflow to FP8.
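
For illustration, a minimal sketch of the FP32 → BF16 direction (a hypothetical helper, not the MIGraphX API): BF16 keeps FP32's sign bit and 8-bit exponent, so this cast just truncates the mantissa to its top 7 bits, whereas FP16 → BF16 additionally has to re-bias the 5-bit exponent to 8 bits, which is why the two casts differ.

#include <cstdint>
#include <cstring>

// FP32 -> BF16 by truncation: keep the upper 16 bits of the FP32 encoding
// (sign, 8-bit exponent, top 7 mantissa bits). A round-to-nearest-even
// variant would add a rounding increment to the low bits before shifting.
inline std::uint16_t fp32_to_bf16_truncate(float f)
{
    std::uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof(bits)); // type-pun safely via memcpy
    return static_cast<std::uint16_t>(bits >> 16);
}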

pfultz2 commented 4 days ago

So I started work on a generic_float class so we can specify any float type:
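
// The snippet assumes a constexpr-capable bit_cast (std::bit_cast from <bit>
// in C++20, or an equivalent in-tree helper) and <tuple> for std::tie below.
#include <tuple>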

template<unsigned int N>
constexpr unsigned int all_ones() noexcept
{
    return (1u << N) - 1;
}

struct float32_parts 
{
    unsigned int mantissa : 23;
    unsigned int exponent : 8;
    unsigned int sign : 1;

    static constexpr unsigned int mantissa_width()
    {
        return 23;
    }

    static constexpr unsigned int max_exponent()
    {
        return all_ones<8>();
    }

    static constexpr int exponent_bias()
    {
        return all_ones<7>();
    }

    constexpr float to_float() const noexcept
    {
        return bit_cast<float>(*this);
    }
};

constexpr float32_parts get_parts(float f)
{
    return bit_cast<float32_parts>(f);
}

template<unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct generic_float
{
    unsigned int mantissa : MantissaSize;
    unsigned int exponent : ExponentSize;
    unsigned int sign : 1;

    static constexpr int exponent_bias()
    {
        return all_ones<ExponentSize - 1>();
    }

    explicit constexpr generic_float(float f = 0.0f) noexcept
    {
        from_float(get_parts(f));
    }

    constexpr float to_float() const noexcept
    {
        float32_parts f{};
        f.sign = sign;
        f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
        if(exponent == all_ones<ExponentSize>())
        {
            f.exponent = float32_parts::max_exponent();
        }
        else if(exponent == 0)
        {
            // zero (and, for now, denormals) keep a zero exponent; proper
            // subnormal rescaling is not handled yet
            f.exponent = 0;
        }
        else
        {
            constexpr const auto diff = float32_parts::exponent_bias() - exponent_bias();
            f.exponent = exponent + diff;
        }
        return f.to_float();
    }

    constexpr void from_float(float32_parts f) noexcept
    {
        sign  = f.sign;
        mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);

        if(f.exponent == 0)
        {
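            // zero and float32 denormals pass straight through; subnormal
            // rescaling is not handled yet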
            exponent = 0;
        }
        else if(f.exponent == float32_parts::max_exponent())
        {
            exponent = all_ones<ExponentSize>();
        }
        else
        {
            constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
            auto e = int(f.exponent) - diff;
            if(e >= int(all_ones<ExponentSize>()))
            {
                // overflow for the narrower type: clamp to infinity
                // (the cast to int matters: comparing the signed e against the
                // unsigned all_ones would wrongly send negative e here)
                exponent = all_ones<ExponentSize>();
                mantissa = 0;
            }
            else if(e < 0)
            {
                // underflow: flush to zero (denormals are not generated yet)
                exponent = 0;
                mantissa = 0;
            }
            else
            {
                exponent = f.exponent - diff;
            }
        }
    }

    constexpr bool is_normal() const noexcept
    {
        return exponent != all_ones<ExponentSize>() and exponent != 0;
    }

    constexpr bool is_inf() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa == 0;
    }

    constexpr bool is_nan() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa != 0;
    }

    constexpr bool is_finite() const noexcept
    {
        return exponent != all_ones<ExponentSize>();
    }

    constexpr operator float() const noexcept
    {
        return this->to_float();
    }

    static constexpr generic_float infinity()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        return x;
    }

    static constexpr generic_float snan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1 << (MantissaSize - 2);
        return x;
    }

    static constexpr generic_float qnan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1 << (MantissaSize - 1);
        return x;
    }

    static constexpr generic_float min()
    {
        generic_float x{};
        x.exponent = 1;
        x.mantissa = 0;
        return x;
    }

    static constexpr generic_float denorm_min()
    {
        generic_float x{};
        x.exponent = 0;
        x.mantissa = 1;
        x.sign = 0;
        return x;
    }

    static constexpr generic_float lowest()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign = 1;
        return x;
    }

    static constexpr generic_float max()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign = 0;
        return x;
    }

    static constexpr generic_float epsilon()
    {
        generic_float x{1.0};
        x.mantissa++;
        return generic_float{x.to_float() - 1.0f};
    }
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
    constexpr generic_float& operator op(const generic_float& rhs) \
    { \
        float self = *this; \
        float frhs = rhs; \
        self op frhs; \
        *this = generic_float(self); \
        return *this; \
    }
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
    friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
    { \
        return generic_float(float(x) op float(y)); \
    }
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
    friend constexpr bool operator op(const generic_float& x, const generic_float& y) \
    { \
        return float(x) op float(y); \
    }
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)

    friend constexpr bool operator==(const generic_float& x, const generic_float& y)
    {
        // NaNs never compare equal; everything else compares by bit pattern
        if(x.is_nan() or y.is_nan())
            return false;
        return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign);
    }

    friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
    {
        return not(x == y);
    }
};

I may be biased, but I do find this much more readable than the float8 code. I have done some initial testing with an fp32 type:

using fp32 = generic_float<23, 8>;

#define CHECK_FLOAT(x, y) \
    CHECK(bit_equal(x, y)); \
    CHECK(bit_equal(x, y.to_float())); \
    CHECK(bit_equal(fp32{x}, y)); \
    CHECK(bit_equal(fp32{x}.to_float(), y.to_float()))

TEST_CASE(fp32_values)
{
    CHECK_FLOAT(1.0f, fp32{1.0f});
    CHECK_FLOAT(-1.0f, fp32{-1.0f});
    CHECK_FLOAT(std::numeric_limits<float>::min(), fp32::min());
    CHECK_FLOAT(std::numeric_limits<float>::lowest(), fp32::lowest());
    CHECK_FLOAT(std::numeric_limits<float>::max(), fp32::max());
    CHECK_FLOAT(std::numeric_limits<float>::epsilon(), fp32::epsilon());
    CHECK_FLOAT(std::numeric_limits<float>::infinity(), fp32::infinity());
    CHECK_FLOAT(std::numeric_limits<float>::quiet_NaN(), fp32::qnan());
    CHECK_FLOAT(std::numeric_limits<float>::signaling_NaN(), fp32::snan());
    CHECK_FLOAT(std::numeric_limits<float>::denorm_min(), fp32::denorm_min());
}

This doesn't test the truncation code, though. Specializations of std::numeric_limits need to be added (a sketch follows), and the Flags parameter also needs to be handled at some point for the fp8 types, but that shouldn't be a blocker for BF16.
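
A rough sketch of what one of those std::numeric_limits specializations could look like, reusing the static members of generic_float above (member set abbreviated; a complete specialization would fill in the remaining constants):

#include <limits>

namespace std {
template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags>
class numeric_limits<generic_float<MantissaSize, ExponentSize, Flags>>
{
    using type = generic_float<MantissaSize, ExponentSize, Flags>;

    public:
    static constexpr bool is_specialized = true;
    static constexpr bool has_infinity   = true;
    static constexpr bool has_quiet_NaN  = true;
    static constexpr type min() noexcept { return type::min(); }
    static constexpr type max() noexcept { return type::max(); }
    static constexpr type lowest() noexcept { return type::lowest(); }
    static constexpr type epsilon() noexcept { return type::epsilon(); }
    static constexpr type infinity() noexcept { return type::infinity(); }
    static constexpr type quiet_NaN() noexcept { return type::qnan(); }
    static constexpr type signaling_NaN() noexcept { return type::snan(); }
    static constexpr type denorm_min() noexcept { return type::denorm_min(); }
};
} // namespace std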

It would be good to start by replacing our current half type with generic_float, since that's already implemented and has initial tests. We need to create a test suite similar to the fp8 test suite, but we probably can't create 64k arrays, so we should create samples of, say, 1k values to test fp16 with; a sketch of that idea is below.
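
A rough sketch of that sampling idea, assuming fp16 maps to generic_float<10, 5> and reusing TEST_CASE, CHECK, and bit_equal from the test harness above (the stride is an arbitrary choice that yields roughly 1k samples):

TEST_CASE(fp16_round_trip_samples)
{
    using fp16 = generic_float<10, 5>;
    // stride through the 64k fp16 bit patterns instead of materializing them all
    for(unsigned int bits = 0; bits < 0x10000; bits += 61)
    {
        fp16 x{};
        x.sign     = (bits >> 15) & 0x1;
        x.exponent = (bits >> 10) & 0x1f;
        x.mantissa = bits & 0x3ff;
        if(x.exponent == 0)
            continue; // skip zero/denormals until subnormal handling lands
        // fp32 represents every fp16 value exactly, so converting to float
        // and back should reproduce the same bit pattern (including NaNs)
        CHECK(bit_equal(fp16{x.to_float()}, x));
    }
}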

Then we can easily add BF16:
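
Presumably just an alias, since BF16 is 1 sign bit, 8 exponent bits, and 7 mantissa bits:

using bf16 = generic_float<7, 8>;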

I split it up like this to allow smaller PRs, which should make them easier to review and merge. So it would be 3 PRs for the above and one more PR for the half type, 4 PRs in total.