
[post-MVP proposal] FMA instruction #1391

Open munrocket opened 3 years ago

munrocket commented 3 years ago

Motivation

Fused multiply–add (FMA) is a floating-point operation performed in one step, with a single rounding. FMA can speed up and improve the accuracy of many computations: dot products, matrix multiplication, convolutions and artificial neural networks, polynomial evaluation, Newton's method, Kahan summation, and the Veltkamp–Dekker algorithm. This instruction exists in languages such as C/C++, Rust, C#, Go, Java, Julia, Swift, and OpenGL 4+.
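
For example, polynomial evaluation with Horner's rule maps directly onto FMA. A minimal sketch in C, using the C99 fma from math.h:

#include <math.h>

// Horner's rule with fma(): each step computes r * x + c[i] with a
// single rounding, which is both faster (one instruction on FMA
// hardware) and more accurate than a separate multiply and add.
static double poly_eval(const double *c, int degree, double x) {
  double r = c[degree];
  for (int i = degree - 1; i >= 0; --i)
    r = fma(r, x, c[i]);
  return r;
}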

Problem and existing solutions

In Wasm there is no way to get a speed improvement from this widely supported instruction. There are two ways to implement a fallback when the hardware does not support it: a correctly rounded one that is slow, and a simple multiply–add combination that is fast but accumulates error.

For example, OpenCL has both. In C/C++ you have to implement the fma fallback yourself, but in Rust/Go/C#/Java/Julia fma is implemented with a correctly rounded fallback. It does not make much difference how it is implemented, because you can always detect the fma feature during runtime initialisation with a special floating-point expression and implement a conditional fallback however you wish in your software.
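
A minimal sketch of such a probe (the names are illustrative): with x = 1 + 2^-52 and y = 1 - 2^-52, the exact product 1 - 2^-104 rounds to 1.0 in double precision, so an unfused multiply-then-add of (x, y, -1) returns 0.0 while a genuinely fused one returns the exact residual -2^-104.

typedef double (*fma_fn)(double, double, double);

// Returns 1 if `candidate` computes x * y + z with a single rounding.
static int is_fused(fma_fn candidate) {
  double x = 1.0 + 0x1p-52; // 1 + 2^-52
  double y = 1.0 - 0x1p-52; // 1 - 2^-52
  // exact x * y = 1 - 2^-104 rounds to exactly 1.0 when the product
  // is rounded separately, so only a fused op sees the residual.
  return candidate(x, y, -1.0) != 0.0;
}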

If at least the first two basic instructions were implemented, it would be a great step forward, because right now software focused on precision needs to implement a correct fma fallback with more than 23 FLOPs instead of one instruction. This could be a finance application, an arbitrary-precision library, or a space flight simulator with orbital dynamics.

Proposed instructions

(f64.fma $x $y $z) ;; equivalent to fma(x, y, z) in C99
(f32.fma $x $y $z) ;; equivalent to fmaf(x, y, z) in C99

Languages usually compile fma(x, y, -z) into a fused multiply–subtract under the hood. Since .wasm is a compilation target, it looks like the whole instruction family can be implemented:

(f64.fms $x $y $z) ;; RN(x * y - z) = fma(x,y,-z)
(f32.fms $x $y $z) ;;
(f64.fnma $x $y $z) ;; RN(-x * y + z)
(f32.fnma $x $y $z) ;;
(f64.fnms $x $y $z) ;; RN(-x * y - z)
(f32.fnms $x $y $z) ;;

But not every architecture implements all of them, because rounding commutes with negation, so the negated variants give the same results (e.g. fnms(x, y, z) = -fma(x, y, z)).
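
To illustrate the folding mentioned above — a hedged sketch, since the exact output depends on target and flags — clang at -O2 typically folds the negated operand into a single fused instruction on AArch64:

#include <math.h>

// Typically compiles to a single `fnmsub d0, d0, d1, d2` on AArch64
// with clang -O2: the negation is absorbed into the fused operation.
double fms(double x, double y, double z) {
  return fma(x, y, -z);
}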

Implementation

Here is a draft of how to implement a software fallback, based on the relatively new Boldo–Melquiond paper.

#include <math.h>   // fma, nextafter, FP_FAST_FMA
#include <stdint.h> // uint64_t
#include <string.h> // memcpy

static inline double two_prod(const double a, const double b, double &err) {
  const double splitter = 0x8000001p0; // 2^27 + 1 for double (binary64)
  //const float splitter = 0x1001p0;   // 2^12 + 1 for float (binary32)

  // Veltkamp split of a into 26-bit halves: a == ah + al
  double t = splitter * a;
  double ah = t + (a - t);
  double al = a - ah;

  // Veltkamp split of b: b == bh + bl
  t = splitter * b;
  double bh = t + (b - t);
  double bl = b - bh;

  // Dekker product: t + err == a * b exactly (barring overflow)
  t = a * b;
  err = ((ah * bh - t) + ah * bl + al * bh) + al * bl;
  return t;
}

static inline double two_sum(const double a, const double b, double &err) {
  // Moller-Knuth branch-free TwoSum: s + err == a + b exactly
  double s = a + b;
  double a1 = s - b;

  err = (a - a1) + (b - (s - a1));
  return s;
}
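
// NOTE: this helper is a sketch added for the draft, not code from the
// paper. It emulates the round-to-odd addition that Boldo-Melquiond
// require, while staying in the default round-to-nearest mode: if the
// rounded sum is inexact and its last significand bit is even, nudge
// it one ulp toward the discarded error so the last bit becomes odd.
static inline double odd_round_sum(const double a, const double b) {
  double err;
  double s = two_sum(a, b, err); // s + err == a + b exactly
  uint64_t bits;
  memcpy(&bits, &s, sizeof bits);
  if (err != 0.0 && (bits & 1) == 0)
    s = nextafter(s, err > 0.0 ? INFINITY : -INFINITY);
  return s;
}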

static inline double fma_correct_fallback(const double x, const double y, const double z) {
  // TODO: check overflows of the intermediate product and scale if needed
  double err1, err2;
  double mul = two_prod(x, y, err1);  // Veltkamp-Dekker multiplication: x * y -> (mul, err1)
  double sum = two_sum(mul, z, err2); // Moller-Knuth summation: mul + z -> (sum, err2)
  // Boldo-Melquiond ternary summation: sum + err2 + err1 -> fma,
  // with the error terms pre-added in round-to-odd (odd_round_sum above)
  return sum + odd_round_sum(err2, err1);
}

static inline double f64_fma(const double x, const double y, const double z) {
  #ifdef FP_FAST_FMA // set when fma() maps to a single hardware instruction
    return fma(x, y, z);
  #else
    return fma_correct_fallback(x, y, z);
  #endif
}
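
As a quick sanity check (assuming the names above), the probe operands from the detection trick exercise exactly the double-rounding case the fallback must get right:

#include <assert.h>

int main(void) {
  // (1 + 2^-52)(1 - 2^-52) = 1 - 2^-104 exactly, so a correct fma
  // with z = -1.0 must return -2^-104, while multiply-then-add gives 0.
  assert(f64_fma(1.0 + 0x1p-52, 1.0 - 0x1p-52, -1.0) == -0x1p-104);
  return 0;
}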

But Chromium and Apple already have their own implementations.

tlively commented 3 years ago

Related: https://github.com/WebAssembly/simd/issues/10 and https://github.com/WebAssembly/simd/pull/79. We'll be looking at how we can best standardize FMA instructions after the current SIMD proposal is finished.

arunetm commented 3 years ago

Thanks for proposing this. FMAs are super useful, especially in AI/ML apps.

This proposal can also pave the way for a new class of scalar & SIMD instructions, i.e. expanding to sqrt_approximations, reci_sqrt_approximations, and float2int conversions in addition to FMA. We have the option of introducing a post-MVP "fast-math" mode for Wasm scalar and SIMD that includes this class of operations, which might have similar feature/platform check requirements.

I have been looking into a few options; we could potentially introduce a fast-math option for Wasm in general, with a dependency on @tlively's conditional-sections/feature-detection proposals. Scalar and SIMD FMAs and the other ops listed above can fit in as the fast-math ops. Does this sound like a reasonable approach? Happy to help with this effort and with extending it to other fast-math operations.

mu578 commented 3 years ago

Hello, this is a bit more subtle, or if I may say even more trivial, than what the OP claimed and proposed.


template <class T>
T fma(const T& x, const T& y, const T& z)
// long double unsupported except if it is an alias to double, or float80; anyways, forget about it, off-topic.
{
    const T cs = TwoProduct<T>::split_value; // BTW 2^12 + 1 on float -> IEEE binary32. assuming radix-2...

    T s = z, w = T(0), h, q, r, x1, x2, y1, y2;

    // can be a loop where x and y are values at index and w, s are accumulators.
    {
//!# TwoProduct(x,y,h,r).
    q  = x;
//!# split x into x1,x2.
    r  = cs * q;
    x2 = r - q;
    x1 = r - x2;
    x2 = q - x1;
    r  = y;
//!# h=x*y.
    h  = q * r;
//!# split y into y1,y2.
    q  = cs * r;
    y2 = q - r;
    y1 = q - y2;
    y2 = r - y1;
//!# r=x2*y2-(((h-x1*y1) - x2*y1) - x1*y2).
    q  = x1 * y1;
    q  = h - q;
    y1 = y1 * x2;
    q  = q - y1;
    x1 = x1 * y2;
    q  = q - x1;
    x2 = x2 * y2;
    r  = x2 - q;
//!# (w,q)=TwoSum(w,h).
    x1 = w + h;
    x2 = x1 - w;
    y1 = x1 - x2;
    y2 = h - x2;
    q  = w - y1;
    q  = q + y2;
    w  = x1;
//!# s=s+(q+r).
    q = q + r;
    s = s + q;
    }

    return w + s;
}

References:

Takeshi Ogita, Siegfried M. Rump, and Shin'ichi Oishi, "Accurate Sum and Dot Product" (accurate dot product sum(x[i] * y[i], i = 0...n-1) of two vectors), SIAM Journal on Scientific Computing (SISC), 26(6):1955-1988, 2005.

Knuth and Møller, 1998.
Dekker, 1971.

Conclusion: use the available hardware FMADD or similar intrinsics; to my recollection, on ARM only the M5 series has full support. That is good for educational purposes, but performance- and accuracy-wise it is a dead end.

munrocket commented 3 years ago

@moe123 the trivial code will fail on an overflow test:

#include <float.h> // DBL_MAX

double expected = fma(DBL_MAX, 2, -DBL_MAX);            // == DBL_MAX (single rounding, no overflow)
double actual = trivial_fallback(DBL_MAX, 2, -DBL_MAX); // == +inf (the intermediate product overflows)

This is useful (and subtle) if you know that there will be no overflow in the code, but it is not suitable for a general fallback implementation.

mvduin commented 12 months ago

to my recollection, on ARM only the M5 series has full support. That is good for educational purposes, but performance- and accuracy-wise it is a dead end.

Pretty much every ARM core that is powerful enough to run a web browser has a proper fused multiply-add. Only the first two ARMv7-A cores (Cortex-A8 and -A9) lack it, as do all of the ARMv7-R cores with an FPU (R4F, R5F, R7F, R8F). Every core since then that has floating-point support includes fused multiply-add instructions, even Cortex-M4F microcontrollers.

mitschabaude commented 3 weeks ago

We'll be looking at how we can best standardize FMA instructions after the current SIMD proposal is finished.

Relaxed SIMD, with f64x2.relaxed_fma and similar instructions, was accepted a while ago and has been released in some runtimes.

So right now we have SIMD versions of FMA, but no scalar version, i.e. f64.fma etc., as the OP suggested.
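
In the meantime, a scalar FMA can be routed through the SIMD instruction by splatting the operands and extracting one lane. A minimal sketch, assuming clang's relaxed-SIMD intrinsics from wasm_simd128.h (where the instruction is spelled relaxed_madd; build with -mrelaxed-simd). Note that relaxed semantics only permit fusion, they do not guarantee it:

#include <wasm_simd128.h>

// Scalar "fma" via f64x2.relaxed_madd: the engine may fuse (single
// rounding) or fall back to multiply-then-add, per relaxed SIMD rules.
static inline double f64_fma_via_simd(double x, double y, double z) {
  v128_t r = wasm_f64x2_relaxed_madd(wasm_f64x2_splat(x),
                                     wasm_f64x2_splat(y),
                                     wasm_f64x2_splat(z));
  return wasm_f64x2_extract_lane(r, 0);
}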

Is it on the roadmap to add scalar versions as well? Or should a new proposal be created for that? @tlively tagging you as you may know (or may recommend a process)

tlively commented 3 weeks ago

There is not currently a proposal that includes a scalar FMA instruction, so we would need a new proposal for it. We have some general guidance on the process for this here. Specific interesting questions to answer for a potential FMA instruction would be: