google / highway

Performance-portable, length-agnostic SIMD with runtime dispatch
Apache License 2.0

arm64 lanes #2248

Open xiaozhuai opened 2 weeks ago

xiaozhuai commented 2 weeks ago

Hi there.

constexpr hn::ScalableTag<uint8_t> df;
constexpr hn::CappedTag<uint8_t, 1> d1;
constexpr size_t N = hn::Lanes(df);

On arm64 with NEON, I expected N to be 32, but it is 16 when I debug. Why?

BTW, here is my original code, which uses NEON directly. What would the equivalent Highway code look like?

for (; dst_ptr <= dst_end - 8 * 4; dst_ptr += 8 * 4, src_ptr += 8 * 4) {
        // load 8 pixels
        uint8x8x4_t rgba = vld4_u8(src_ptr);

        // multiply the color by alpha, expand to 16-bit
        uint16x8_t r = vmull_u8(rgba.val[0], rgba.val[3]);
        uint16x8_t g = vmull_u8(rgba.val[1], rgba.val[3]);
        uint16x8_t b = vmull_u8(rgba.val[2], rgba.val[3]);

        // (x + 127) / 255 == (x + ((x + 128) >> 8) + 128) >> 8
        // This form is well suited to NEON:
        // vrshrq_n_u16(...,8) gives the inner (x+128)>>8,
        // vraddhn_u16() both the outer add-shift and our conversion back to 8-bit.
        rgba.val[0] = vraddhn_u16(r, vrshrq_n_u16(r, 8));
        rgba.val[1] = vraddhn_u16(g, vrshrq_n_u16(g, 8));
        rgba.val[2] = vraddhn_u16(b, vrshrq_n_u16(b, 8));
}
// process the remaining pixels
for (; dst_ptr < dst_end; dst_ptr += 4, src_ptr += 4) {
    dst_ptr[0] = (src_ptr[0] * src_ptr[3] + 127) / 255;
    dst_ptr[1] = (src_ptr[1] * src_ptr[3] + 127) / 255;
    dst_ptr[2] = (src_ptr[2] * src_ptr[3] + 127) / 255;
}
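The rounding identity quoted in the code comment above can be checked exhaustively in scalar code. Below is a minimal sketch in plain C++ (no NEON or Highway), assuming x is the product of two 8-bit values and therefore at most 255 * 255. The function names are made up for illustration and are not part of the original code.

```cpp
#include <cstdint>

// Reference form: rounded division by 255, as in the scalar tail loop.
constexpr uint32_t DivRound255(uint32_t x) { return (x + 127) / 255; }

// Shift-only form from the comment: (x + ((x + 128) >> 8) + 128) >> 8.
constexpr uint32_t DivRound255Shift(uint32_t x) {
  return (x + ((x + 128) >> 8) + 128) >> 8;
}

// Returns true if both forms agree for every product of two 8-bit values.
bool IdentityHolds() {
  for (uint32_t x = 0; x <= 255u * 255u; ++x) {
    if (DivRound255(x) != DivRound255Shift(x)) return false;
  }
  return true;
}
```

The largest intermediate value is 65025 + 254 + 128 = 65407, which still fits in 16 bits. That is why the NEON version can stay in uint16x8_t throughout.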
jan-wassenberg commented 2 weeks ago

Hi, NEON registers are 128 bit, so we would indeed expect N=128/CHAR_BIT / sizeof(uint8_t) = 16, right?
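The lane arithmetic above can be written out as a small compile-time check. This is only an illustration in plain C++; LanesPerVector is a hypothetical helper, not a Highway function, and the 256-bit case is included merely to show where 32 lanes of uint8_t would come from.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>

// Number of lanes of type T in a vector of the given width:
// bits / CHAR_BIT gives bytes, divided by the lane size in bytes.
template <typename T>
constexpr std::size_t LanesPerVector(std::size_t vector_bits) {
  return vector_bits / CHAR_BIT / sizeof(T);
}

// 128-bit NEON registers hold 16 uint8_t lanes or 8 uint16_t lanes.
static_assert(LanesPerVector<uint8_t>(128) == 16, "128-bit: 16 x u8");
static_assert(LanesPerVector<uint16_t>(128) == 8, "128-bit: 8 x u16");
// 32 uint8_t lanes would require a 256-bit vector.
static_assert(LanesPerVector<uint8_t>(256) == 32, "256-bit: 32 x u8");
```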

vld4 is LoadInterleaved4, vmull is MulHigh. We don't (yet) have vraddhn nor vrshrq_n - those rounding instructions are specific to Arm, and I am not sure how feasible/desirable it is to emulate them elsewhere.

xiaozhuai commented 2 weeks ago

@jan-wassenberg Thanks for the reply.

We don't (yet) have vraddhn nor vrshrq_n - those rounding instructions are specific to Arm, and I am not sure how feasible/desirable it is to emulate them elsewhere.

Currently I use NEON_2_SSE.h on non-Arm platforms, and I am wondering whether I could replace all SIMD usage in my project with Highway. I have read the documentation, but I am still confused. Is there some demo code that would get me started quickly? Thanks a lot!

_NEON2SSE_INLINE uint8x8_t  vraddhn_u16(uint16x8_t a, uint16x8_t b) // VRADDHN.I16 d0,q0,q0
{
    uint8x8_t res64;
    __m128i sum, mask1;
    sum = _mm_add_epi16 (a, b);
    mask1 = _mm_slli_epi16(sum, 8); //shift left then back right to
    mask1 = _mm_srli_epi16(mask1, 15); //get  7-th bit 1 or zero
    sum = _mm_srai_epi16 (sum, 8); //get high half
    sum = _mm_add_epi16 (sum, mask1); //actual rounding
    sum = _mm_packus_epi16 (sum, sum);
    return64(sum);
}
_NEON2SSE_INLINE uint16x8_t vrshrq_n_u16(uint16x8_t a, __constrange(1,16) int b) // VRSHR.S16 q0,q0,#16
{
    __m128i maskb, r;
    maskb =  _mm_slli_epi16(a, (16 - b)); //to get rounding (b-1)th bit
    maskb = _mm_srli_epi16(maskb, 15); //1 or 0
    r = _mm_srli_epi16 (a, b);
    return _mm_add_epi16 (r, maskb); //actual rounding
}
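For reference, the per-lane behaviour of the two intrinsics can be modelled in scalar C++. This is a sketch based on my reading of the Arm semantics (rounding shift right, and "add, round, keep the high half"), not code from NEON_2_SSE.h or Highway. The function names are hypothetical.

```cpp
#include <cstdint>

// Scalar model of one lane of vrshrq_n_u16(a, n), 1 <= n <= 16.
// Widened to 32 bits so that adding the rounding constant cannot overflow.
inline uint16_t RoundingShiftRightU16(uint16_t a, int n) {
  return static_cast<uint16_t>(
      (static_cast<uint32_t>(a) + (1u << (n - 1))) >> n);
}

// Scalar model of one lane of vraddhn_u16(a, b):
// keeps bits 15..8 of a + b + 128.
inline uint8_t RoundingAddHighNarrowU16(uint16_t a, uint16_t b) {
  return static_cast<uint8_t>((static_cast<uint32_t>(a) + b + 128u) >> 8);
}

// One channel of the premultiply: widening multiply (as vmull_u8 does),
// then the rounded division by 255 built from the two ops above.
inline uint8_t PremultiplyChannel(uint8_t c, uint8_t a) {
  const uint16_t x = static_cast<uint16_t>(c * a);
  return RoundingAddHighNarrowU16(x, RoundingShiftRightU16(x, 8));
}
```

A scalar model like this is mainly useful as a test oracle when porting the vector code to another instruction set.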
jan-wassenberg commented 2 weeks ago

Sure, check out hwy/examples/* or the Godbolt examples linked in the README under "## Examples". Is that what you had in mind?

The above implementations look reasonable. You could get a platform-independent Highway version by opening x86_128-inl.h and searching for the intrinsic name (_mm_slli_epi16) to get the equivalent Highway name (ShiftLeft).

xiaozhuai commented 1 week ago

Finally, I came up with this code, and it works well, but I am not sure whether it follows best practices. Could you give me some hints?

The following algorithm implements the premultiplication of RGBA8888. The formula is (c * a + 127) / 255

    const hn::ScalableTag<uint16_t> d;
    const hn::CappedTag<uint16_t, 1> d1;
    const size_t N = hn::Lanes(d);
    auto loop = [](const auto d, uint8_t *HWY_RESTRICT dst_ptr, const uint8_t *HWY_RESTRICT src_ptr) HWY_ATTR {
        using DU8 = hn::Rebind<uint8_t, decltype(d)>;
        hn::Vec<DU8> r;
        hn::Vec<DU8> g;
        hn::Vec<DU8> b;
        hn::Vec<DU8> a;
        hn::LoadInterleaved4(DU8(), src_ptr, r, g, b, a);
        auto r16 = hn::PromoteTo(d, r);
        auto g16 = hn::PromoteTo(d, g);
        auto b16 = hn::PromoteTo(d, b);
        auto a16 = hn::PromoteTo(d, a);
        r16 = hn::Div(hn::MulAdd(r16, a16, hn::Set(d, 127)), hn::Set(d, 255));
        g16 = hn::Div(hn::MulAdd(g16, a16, hn::Set(d, 127)), hn::Set(d, 255));
        b16 = hn::Div(hn::MulAdd(b16, a16, hn::Set(d, 127)), hn::Set(d, 255));
        r = hn::DemoteTo(DU8(), r16);
        g = hn::DemoteTo(DU8(), g16);
        b = hn::DemoteTo(DU8(), b16);
        hn::StoreInterleaved4(r, g, b, a, DU8(), dst_ptr);
    };
    for (; dst_ptr + N * 4 <= dst_end; dst_ptr += N * 4, src_ptr += N * 4) {
        loop(d, dst_ptr, src_ptr);
    }
    for (; dst_ptr + 4 <= dst_end; dst_ptr += 4, src_ptr += 4) {
        loop(d1, dst_ptr, src_ptr);
    }

BTW, I'm not sure which of the following I should use.

// Option 1: 16-bit tags
const hn::ScalableTag<uint16_t> d;
const hn::CappedTag<uint16_t, 1> d1;

// Option 2: 8-bit tags
const hn::ScalableTag<uint8_t> d;
const hn::CappedTag<uint8_t, 1> d1;
jan-wassenberg commented 1 week ago

Hi, Div can be quite expensive. For this frequent use case of multiplying pixels, it is more common to use MulHigh or MulFixedPoint15.
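One way to read the MulHigh suggestion is to replace the division by 255 with a multiplication by a fixed-point reciprocal. The sketch below models this in scalar C++. The constant 0x8081 (which is ceil(2^23 / 255)) and the extra shift by 7 are my assumptions and do not come from this thread; MulHighU16 is a scalar stand-in for a 16-bit "upper half of the product" op.

```cpp
#include <cstdint>

// Scalar model of a 16-bit MulHigh: upper 16 bits of the 32-bit product.
constexpr uint16_t MulHighU16(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>((static_cast<uint32_t>(a) * b) >> 16);
}

// floor(x / 255) for a 16-bit x: multiply by ceil(2^23 / 255) = 0x8081,
// then shift right by 23 in total (16 from MulHigh plus 7 here).
constexpr uint16_t Div255(uint16_t x) {
  return static_cast<uint16_t>(MulHighU16(x, 0x8081) >> 7);
}

// Rounded premultiply of one channel: (c * a + 127) / 255.
// c * a + 127 is at most 65152, so it still fits in 16 bits.
constexpr uint8_t PremultiplyChannel(uint8_t c, uint8_t a) {
  return static_cast<uint8_t>(Div255(static_cast<uint16_t>(c * a + 127)));
}
```

In vector code this would cost one add, one high multiply and one shift per channel instead of a division.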

As to which tag: as currently written, Rebind uses the same number of 8-bit elements as 16-bit, so the LoadInterleaved4 is using half vectors. It does not matter whether you have 8-bit or 16-bit passed to ScalableTag, but generally it is safer to use the larger one. Then you can use Repartition instead of Rebind, so that you get twice as many bytes per loop iteration. For promoting you can then use PromoteUpper/LowerTo, then OrderedDemote2To to get an 8-bit vector to Store.
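The difference between the two can be illustrated with a toy model in plain C++. ToyTag, ToyRebind and ToyRepartition are hypothetical stand-ins written for this example. They only mimic the lane-count behaviour described above (Rebind keeps the number of lanes, Repartition keeps the number of bytes) and are not Highway's actual types.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a tag: lane type T and lane count N.
template <typename T, std::size_t N>
struct ToyTag {
  using Lane = T;
  static constexpr std::size_t kLanes = N;
  static constexpr std::size_t kBytes = N * sizeof(T);
};

// Rebind-like: new lane type, same number of lanes.
template <typename U, typename Tag>
using ToyRebind = ToyTag<U, Tag::kLanes>;

// Repartition-like: new lane type, same number of bytes.
template <typename U, typename Tag>
using ToyRepartition = ToyTag<U, Tag::kBytes / sizeof(U)>;

using D16 = ToyTag<uint16_t, 8>;                 // full 128-bit vector of u16
using D8Rebind = ToyRebind<uint8_t, D16>;        // 8 lanes = 8 bytes (half vector)
using D8Repart = ToyRepartition<uint8_t, D16>;   // 16 lanes = 16 bytes (full vector)

static_assert(D8Rebind::kBytes == 8, "Rebind to u8 halves the byte count");
static_assert(D8Repart::kBytes == 16, "Repartition keeps the byte count");
```

With the Repartition-style tag, each 8-bit vector holds twice as many lanes as a 16-bit vector, which is why two promotions (lower and upper half) are needed per channel.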

johnplatts commented 1 week ago

We don't (yet) have vraddhn nor vrshrq_n - those rounding instructions are specific to Arm, and I am not sure how feasible/desirable it is to emulate them elsewhere.

RVV actually does have vssrl and vssra rounding right shift instructions that are equivalent to NEON/SVE2 SRSHL/SRSHR/URSHL/URSHR (which are wrapped by the NEON vrshl/vrshr intrinsics and SVE2 svrshl/svrshr intrinsics).

RVV/PPC/NEON/SVE2 have instructions for signed integer AverageRound and I32/U32 AverageRound in addition to the U8/U16 AverageRound op that is currently implemented.

johnplatts commented 1 week ago

Here is how a RoundingShiftRight op could be implemented for targets other than RVV/NEON/SVE2:

template <int kShiftAmt, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V RoundingShiftRight(V v) {
  using D = DFromV<decltype(v)>;
  using T = TFromD<D>;
  constexpr int kMaxShiftAmt = static_cast<int>(sizeof(T) * 8 - 1);
  static_assert(0 <= kShiftAmt && kShiftAmt <= kMaxShiftAmt,
                "kShiftAmt must be between 0 and sizeof(T) * 8 - 1");

  HWY_IF_CONSTEXPR(kShiftAmt > 0) {
    const D d;
    const RebindToUnsigned<D> du;
    return Add(ShiftRight<kShiftAmt>(v),
               And(BitCast(d, ShiftRight<((kShiftAmt - 1) & kMaxShiftAmt)>(
                                  BitCast(du, v))),
                   Set(d, T{1})));
  }
  else {
    return v;
  }
}

Here is another possible implementation of RoundingShiftRight (this needs the missing AverageRound ops for I8/I16/I32/I64/U32/U64 to be implemented before it can support lane types other than U8/U16):

template <int kShiftAmt, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V RoundingShiftRight(V v) {
  using D = DFromV<decltype(v)>;
  using T = TFromD<D>;
  constexpr int kMaxShiftAmt = static_cast<int>(sizeof(T) * 8 - 1);
  static_assert(
    0 <= kShiftAmt && kShiftAmt <= kMaxShiftAmt,
    "kShiftAmt must be between 0 and sizeof(T) * 8 - 1");

  HWY_IF_CONSTEXPR (kShiftAmt > 0) {
    const D d;
    return AverageRound(
      ShiftRight<((kShiftAmt - 1) & kMaxShiftAmt)>(v),
      Zero(d));
  } else {
    return v;
  }
}
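To see that the two variants above compute the same thing, here is a scalar model for unsigned 16-bit lanes in plain C++. It is a sketch for checking the arithmetic only: the function names are hypothetical, AverageRoundU16 assumes the usual (a + b + 1) >> 1 definition, and signed lanes are not covered.

```cpp
#include <cstdint>

// Reference: add half of the divisor, then shift (computed in 32 bits).
template <int kShift>
constexpr uint16_t RoundingShiftRightRef(uint16_t v) {
  return static_cast<uint16_t>(
      (static_cast<uint32_t>(v) + (1u << (kShift - 1))) >> kShift);
}

// First variant: shift, then add the last bit that was shifted out.
template <int kShift>
constexpr uint16_t RoundingShiftRightBit(uint16_t v) {
  return static_cast<uint16_t>((v >> kShift) + ((v >> (kShift - 1)) & 1));
}

// Scalar model of an unsigned AverageRound: (a + b + 1) >> 1, no overflow.
constexpr uint16_t AverageRoundU16(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>((static_cast<uint32_t>(a) + b + 1) >> 1);
}

// Second variant: shift by one less, then average with zero.
template <int kShift>
constexpr uint16_t RoundingShiftRightAvg(uint16_t v) {
  return AverageRoundU16(static_cast<uint16_t>(v >> (kShift - 1)), 0);
}

// Returns true if all three agree for every 16-bit input (1 <= kShift <= 15).
template <int kShift>
bool VariantsAgree() {
  for (uint32_t i = 0; i <= 0xFFFFu; ++i) {
    const uint16_t v = static_cast<uint16_t>(i);
    const uint16_t expected = RoundingShiftRightRef<kShift>(v);
    if (RoundingShiftRightBit<kShift>(v) != expected) return false;
    if (RoundingShiftRightAvg<kShift>(v) != expected) return false;
  }
  return true;
}
```

The second variant needs only a shift and an average, whereas the first needs two shifts, an AND and an add.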
jan-wassenberg commented 1 week ago

Thanks for adding the AverageRound in #2249. The second RoundingShiftRight looks just about fast enough on most platforms if you agree this is a useful op to add :)

xiaozhuai commented 1 week ago

Adding RoundingShiftRight would help a lot in my case. Thanks, @johnplatts. BTW, is it possible to add RoundingAddHalfNarrow and RoundingSubHalfNarrow as well?

xiaozhuai commented 1 week ago

@jan-wassenberg With your help, I think I got things right. The performance of the current version is almost the same as what I first achieved using NEON directly. Once RoundingShiftRight is implemented, it can be improved further. Here is the code. Do you have any other suggestions? Thanks again!

void premultiply_u8_hwy(uint8_t *HWY_RESTRICT dst, const uint8_t *HWY_RESTRICT src,  //
                        int width, int height) {
    constexpr int channels = 4;
    auto *dst_ptr = dst;
    const auto *src_ptr = src;
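    // Cast before multiplying so that large images do not overflow int.
    const auto *const dst_end = dst_ptr + static_cast<size_t>(width) * height * channels;
    // Lambdas that call Highway ops need HWY_ATTR before the opening brace.
    auto premultiply = [](const auto d16, const auto &c, const auto &a) HWY_ATTR {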
        auto tmp0 = hn::Mul(c, a);
        // (x + 127) / 255 == (x + ((x + 128) >> 8) + 128) >> 8
        auto tmp1 = hn::ShiftRight<8>(hn::Add(tmp0, hn::Set(d16, 128)));
        auto tmp2 = hn::Add(tmp0, tmp1);
        return hn::ShiftRight<8>(hn::Add(tmp2, hn::Set(d16, 128)));
        // auto tmp1 = hn::RoundingShiftRight<8>(tmp0);
        // auto tmp2 = hn::Add(tmp0, tmp1);
        // return hn::RoundingShiftRight<8>(tmp2);
    };
    {
        constexpr hn::ScalableTag<uint8_t> d8;
        constexpr hn::Repartition<uint16_t, decltype(d8)> d16;
        // Lanes() is not a compile-time constant on scalable targets (SVE, RVV).
        const size_t N = hn::Lanes(d8);
        for (; dst_ptr + N * channels <= dst_end; dst_ptr += N * channels, src_ptr += N * channels) {
            hn::Vec<decltype(d8)> r, g, b, a;
            hn::LoadInterleaved4(d8, src_ptr, r, g, b, a);

            auto r16_lower = hn::PromoteLowerTo(d16, r);
            auto g16_lower = hn::PromoteLowerTo(d16, g);
            auto b16_lower = hn::PromoteLowerTo(d16, b);
            auto a16_lower = hn::PromoteLowerTo(d16, a);
            auto r16_upper = hn::PromoteUpperTo(d16, r);
            auto g16_upper = hn::PromoteUpperTo(d16, g);
            auto b16_upper = hn::PromoteUpperTo(d16, b);
            auto a16_upper = hn::PromoteUpperTo(d16, a);

            r16_lower = premultiply(d16, r16_lower, a16_lower);
            g16_lower = premultiply(d16, g16_lower, a16_lower);
            b16_lower = premultiply(d16, b16_lower, a16_lower);
            r16_upper = premultiply(d16, r16_upper, a16_upper);
            g16_upper = premultiply(d16, g16_upper, a16_upper);
            b16_upper = premultiply(d16, b16_upper, a16_upper);

            r = hn::OrderedDemote2To(d8, r16_lower, r16_upper);
            g = hn::OrderedDemote2To(d8, g16_lower, g16_upper);
            b = hn::OrderedDemote2To(d8, b16_lower, b16_upper);
            hn::StoreInterleaved4(r, g, b, a, d8, dst_ptr);
        }
    }
    {
        constexpr hn::CappedTag<uint8_t, 1> d8;
        constexpr hn::Rebind<uint16_t, decltype(d8)> d16;
        for (; dst_ptr + channels <= dst_end; dst_ptr += channels, src_ptr += channels) {
            hn::Vec<decltype(d8)> r, g, b, a;
            hn::LoadInterleaved4(d8, src_ptr, r, g, b, a);

            auto r16 = hn::PromoteTo(d16, r);
            auto g16 = hn::PromoteTo(d16, g);
            auto b16 = hn::PromoteTo(d16, b);
            auto a16 = hn::PromoteTo(d16, a);

            r16 = premultiply(d16, r16, a16);
            g16 = premultiply(d16, g16, a16);
            b16 = premultiply(d16, b16, a16);

            r = hn::DemoteTo(d8, r16);
            g = hn::DemoteTo(d8, g16);
            b = hn::DemoteTo(d8, b16);
            hn::StoreInterleaved4(r, g, b, a, d8, dst_ptr);
        }
    }
}
jan-wassenberg commented 1 week ago

Looks quite good, and @johnplatts has indeed implemented RoundingShift, which is awesome :) I am currently traveling and will review on Tue.

Your code looks good to me. You could consider using LoadN for the last iteration, which is likely faster than a scalar loop.
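The general idea behind a partial load for the tail, sketched in plain C++ without Highway: copy the remaining pixels into a zero-padded block, run the same block kernel once, and copy back only the valid pixels. Everything here is an assumption for illustration. kBlockPixels stands in for the vector length, PremultiplyBlock stands in for the vector loop body, and whether LoadN combines well with the interleaved RGBA layout is not covered by this thread.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t kBlockPixels = 16;  // hypothetical vector length in pixels
constexpr std::size_t kChannels = 4;      // RGBA

// Stand-in for the vector loop body: processes exactly kBlockPixels pixels.
inline void PremultiplyBlock(uint8_t* dst, const uint8_t* src) {
  for (std::size_t p = 0; p < kBlockPixels; ++p) {
    const uint32_t a = src[p * kChannels + 3];
    for (std::size_t c = 0; c < 3; ++c) {
      dst[p * kChannels + c] =
          static_cast<uint8_t>((src[p * kChannels + c] * a + 127) / 255);
    }
    dst[p * kChannels + 3] = src[p * kChannels + 3];  // alpha is copied as-is
  }
}

// Full blocks first, then one padded block for the remainder.
inline void PremultiplyPadded(uint8_t* dst, const uint8_t* src,
                              std::size_t num_pixels) {
  std::size_t i = 0;
  for (; i + kBlockPixels <= num_pixels; i += kBlockPixels) {
    PremultiplyBlock(dst + i * kChannels, src + i * kChannels);
  }
  const std::size_t remaining = num_pixels - i;
  if (remaining != 0) {
    uint8_t in[kBlockPixels * kChannels] = {};  // zero padding, like a partial load
    uint8_t out[kBlockPixels * kChannels];
    std::memcpy(in, src + i * kChannels, remaining * kChannels);
    PremultiplyBlock(out, in);
    // Write back only the valid pixels so nothing past the end is touched.
    std::memcpy(dst + i * kChannels, out, remaining * kChannels);
  }
}
```

The tail then costs one extra block iteration plus two small copies, instead of up to kBlockPixels - 1 scalar iterations.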