romeric / fastapprox

Approximate and vectorized versions of common mathematical functions

Faster, more accurate & fault-tolerant log/exp functions #5

Open EvanBalster opened 4 years ago

EvanBalster commented 4 years ago

I've been experimenting with fast approximation functions as an evening hobby lately. This started with a project to generalize the fast-inverse-square-root hack, and I may create another pull request to add a collection of fast roots and inverse roots to fastapprox.

With this PR, I've modified the scalar functions to use fast float-to-int conversion, avoiding int/float casts.

  // fasterlog2abs: approximate log2(|x|)
  union {float f; uint32_t i;} xv = {x}, lv;
  // shift the exponent (and top mantissa bits) into lv's mantissa field,
  // producing a float near 383 + log2(|x|)
  lv.i = 0x43800000u | (xv.i >> 8u);
  // subtract the tuned bias (~256 + 127) to recover log2
  return (lv.f - 382.95695f);

This results in an appreciable improvement in speed and appears to improve accuracy as well (I did a little tuning). The modified functions are also a little more fault tolerant, with log functions acting as log(abs(x)) and exp functions flushing to zero for very small arguments.

These changes should be simple to generalize to SSE but I've decided to submit this PR without those changes. I believe small further improvements in the accuracy of fastlog2 and fastpow2 are possible by experimenting with the coefficients, but in any case all modified functions exhibit a reduction in worst-case error in my tests.

Error and performance statistics for the modified functions, here designated "abs", are below. I tested on an x86_64 MacBook Pro. My testbed tries every float within an exponential range rather than sampling randomly.

    Approximate log2(x) with fasterlog2abs
    Error:
        RMS:  0.0236228
        mean: -0.0145537
        min:  -0.0430415 @ 0.180279
        max:  0.0430603 @ 0.125
    Approximate log2(x) with fasterlog2
    Error:
        RMS:  0.0220738
        mean: 8.36081e-06
        min:  -0.0287746 @ 2.88496
        max:  0.0573099 @ 3.99998
    Approximate log2(x) with fastlog2abs
    Error:
        RMS:  4.62902e-05
        mean: -6.39351e-06
        min:  -8.9407e-05 @ 0.226345
        max:  8.79765e-05 @ 5.66174
    Approximate log2(x) with fastlog2
    Error:
        RMS:  8.78776e-05
        mean: -6.90928e-05
        min:  -0.000150681 @ 7.22098
        max:  1.16825e-05 @ 5.67368
    Approximate 2^x with fasterpow2abs
    Error:
        RMS:  0.0204659
        mean: 0.0103297
        min:  -0.0294163 @ 1.04308
        max:  0.0302728 @ 1.48549
    Approximate 2^x with fasterpow2
    Error:
        RMS:  0.0176322
        mean: 0.000165732
        min:  -0.038947 @ 1.05731
        max:  0.0201453 @ 1.50017
    Approximate 2^x with fastpow2abs
    Error:
        RMS:  2.02567e-05
        mean: 7.34409e-07
        min:  -6.03545e-05 @ 1.14546
        max:  6.67494e-05 @ 1.00017
    Approximate 2^x with fastpow2
    Error:
        RMS:  2.84377e-05
        mean: -2.16665e-05
        min:  -6.83582e-05 @ 1.13344
        max:  2.30076e-05 @ 1.99999

   CPU TIME  |  FORMULA    
------------ + ------------
    60345257 | y (baseline for other measurements)
    40318946 | fasterlog2
     4506380 | fasterlog2abs
    36011548 | fastlog2
    31941574 | fastlog2abs
   120095855 | std::log2
    14267362 | fasterpow2abs
    18984349 | fasterpow2
    67782074 | fastpow2abs
    95174888 | fastpow2
    33921997 | std::exp2
------------ + ------------
    54152142 | y (baseline for other measurements)
    19949223 | fasterlog2
    12560614 | fasterlog2abs
    41892165 | fastlog2
    36505168 | fastlog2abs
   125374347 | std::log2
    20324629 | fasterpow2abs
    25769136 | fasterpow2
    73251135 | fastpow2abs
   100127989 | fastpow2
    41127741 | std::exp2
------------ + ------------
    54386750 | y (baseline for other measurements)
    23603119 | fasterlog2
    15133763 | fasterlog2abs
    62677897 | fastlog2
    47765052 | fastlog2abs
   124088143 | std::log2
    21256368 | fasterpow2abs
    25930862 | fasterpow2
    73285208 | fastpow2abs
   103932786 | fastpow2
    39983032 | std::exp2
------------ + ------------
JobLeonard commented 4 years ago

(not a maintainer of this repo)

Nice! By the way, this confused me for a second:

  RMS:  0.0220738
  mean: 8.36081e-06
  min:  -0.0287746 @ 2.88496
  max:  0.0573099 @ 3.99998

Wouldn't it make sense to (also) calculate a mean error value with only absolute values?

EvanBalster commented 4 years ago

Hello —

I'm an audio/DSP programmer... Depending on the application, different error metrics matter more. The optimal constants change depending on which type of error you're trying to minimize; fastapprox minimizes the "infinity-norm" error, max(abs(approx-actual)).

In signal processing, the mean error can show up as a zero-frequency "DC offset" introduced into the resulting signal, or as a small bias in a measurement. The RMS can predict harmonic interference. Even when optimizing against maximum absolute error, it's useful to know the minimum as well, in case later processing magnifies one more than the other.

Lastly, when searching for optimal constants, evaluating the min and max together permits a nifty heuristic: when the function's infinity-norm error is close to optimal, the minimum and maximum are very similar in magnitude.

EvanBalster commented 4 years ago

[Tagging @pmineiro in case he has any thoughts about this PR.]

MathGeniusJodie commented 3 years ago

It's kinda easier to reason about the exponent like this, don't you think?

  log2exponent.i = 0x4b000000u | (input.i >> 23u); // 2^23 + biased exponent
  log2exponent.f -= 8388735.f;                     // 2^23 + 127

MathGeniusJodie commented 3 years ago

Also, the log2 curve can be approximated really well with a polynomial, so no division is needed.

MathGeniusJodie commented 3 years ago

Here's what I did for exp2, if anyone is interested: a polynomial curve with an exact fractional part.

// ./lolremez --float -d 4 -r "1:2" "exp2(x-1)"
// rms error 0.000003
// max error 0.000004
float exp2xminus1(float x){
    float u = 1.3697664e-2f;
    u = u * x + -3.1002997e-3f;
    u = u * x + 1.6875336e-1f;
    u = u * x + 3.0996965e-1f;
    return u * x + 5.1068333e-1f;
}

float fastexp2 (float p){
  // exp2(floor(p)) * exp2(fract(p)) == exp2(p)
  // 383 is 2^8 + 127
  union {float f; uint32_t i;} u = {p + 383.f}, fract;
  // shove the mantissa bits into the exponent
  u.i <<= 8u;

  // the remaining mantissa bits are the fract of p
  fract.i = 0x3f800000u | (u.i & 0x007fff00u);
  // optional: fix the precision lost when rounding p + 383.f
  fract.f += p - ((p + 383.f) - 383.f);
  //fract.f -= 1.f;
  // keep only the exponent, i.e. 2^floor(p)
  u.i &= 0x7F800000u;
  return u.f * exp2xminus1(fract.f);
}
ghuls commented 2 years ago

A log1p implementation would be nice to have too.