halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Fast atan and atan2 functions. #8388

Open mcourteaux opened 3 months ago

mcourteaux commented 3 months ago

Addresses #8243. Uses a polynomial approximation with odd powers, which makes it immediately symmetric around 0. Coefficients are optimized using my script, which performs iterative weight-adjusted least-squares error minimization (also included in this PR; see below).
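To illustrate the core idea (a sketch, not the PR's actual script): fit an odd-power polynomial to atan on [-1, 1] with plain least squares, and use the classic range reduction atan(x) = sign(x)·pi/2 − atan(1/x) for |x| > 1. The PR's script additionally re-weights samples iteratively to push toward minimax/ULP-optimal coefficients.

```python
import numpy as np

# Fit atan on [-1, 1] using only odd powers x, x^3, ..., x^11 (6 terms,
# degree 11), so the approximation is symmetric around 0 by construction.
xs = np.linspace(-1.0, 1.0, 4001)
powers = np.arange(1, 12, 2)
A = xs[:, None] ** powers
coeffs, *_ = np.linalg.lstsq(A, np.arctan(xs), rcond=None)

def fast_atan_sketch(v):
    v = np.asarray(v, dtype=np.float64)
    big = np.abs(v) > 1.0
    # Range reduction: atan(x) = sign(x)*pi/2 - atan(1/x) for |x| > 1.
    safe = np.where(big, v, 1.0)      # avoid dividing by values near zero
    t = np.where(big, 1.0 / safe, v)
    p = (t[..., None] ** powers) @ coeffs
    return np.where(big, np.sign(v) * np.pi / 2 - p, p)

max_err = np.max(np.abs(fast_atan_sketch(xs) - np.arctan(xs)))
```

Even this unweighted fit already lands well below 1e-4 maximum absolute error on the reduced domain.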

Added API

/**
 * Struct that allows the user to specify several requirements for functions
 * that are approximated by polynomial expansions. These polynomials can be
 * optimized for four different metrics: Mean Squared Error, Maximum Absolute Error,
 * Maximum Units in Last Place (ULP) Error, or a 50%/50% blend of MAE and MULPE.
 *
 * Orthogonally to the optimization objective, these polynomials can vary
 * in degree. Higher degree polynomials will give more precise results.
 * Note that the number of terms is specified rather than the degree: even
 * symmetric functions may be implemented using only even powers, so for
 * those a term count of 4 means that the terms [1, x^2, x^4, x^6] are used,
 * which is degree 6.
 *
 * Additionally, if you don't care about the number of terms in the polynomial
 * but do care about the maximal absolute error the approximation may have
 * over the domain, you may specify a maximum absolute error and the
 * implementation will choose the appropriate polynomial degree that achieves
 * this precision.
 */
struct ApproximationPrecision {
    enum OptimizationObjective {
        MSE, //< Mean Squared Error Optimized.
        MAE, //< Optimized for Max Absolute Error.
        MULPE, //< Optimized for Max ULP Error. ULP is "Units in Last Place", measured in IEEE 32-bit floats.
        MULPE_MAE, //< Optimized for simultaneously Max ULP Error, and Max Absolute Error, each with a weight of 50%.
    } optimized_for;
    int constraint_min_poly_terms{0}; //< Number of terms in polynomial (zero for no constraint).
    float constraint_max_absolute_error{0.0f}; //< Max absolute error (zero for no constraint).
};

/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
 *
 * Desired precision can be specified as either a maximum absolute error (MAE) or
 * the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
 * are optimized for either:
 *  - MSE (Mean Squared Error)
 *  - MAE (Maximum Absolute Error)
 *  - MULPE (Maximum Units in Last Place Error).
 *
 * The default (Max ULP Error Polynomial of 6 terms) has a MAE of 3.53e-6.
 * For more info on the available approximations and their precisions, see the table in ApproximationTables.cpp.
 *
 * Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
 * Note: Poly8 is only useful to increase precision for atan, and not for atan2.
 * Note: These functions do not seem to be reliably faster on WebGPU (for now, August 2024).
 */
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = {ApproximationPrecision::MULPE, 6});
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision precision = {ApproximationPrecision::MULPE, 6});
// @}

I designed this new ApproximationPrecision such that it can also be used for other vectorizable functions later, such as fast_sin and fast_cos if we want those at some point. Note that I chose the MAE_1e_5 style of notation instead of 5Decimals, because "5 decimals" suggests that 5 decimals will be correct, which is technically less accurate than saying that the maximal absolute error will be below 1e-5.

Performance difference:

Linux/CPU:

                  atan: 7.427325 ns per atan
 fast_atan (MAE 1e-02): 0.604592 ns per atan (91.9% faster)  [per invokation: 2.535843 ms]
 fast_atan (MAE 1e-03): 0.695281 ns per atan (90.6% faster)  [per invokation: 2.916222 ms]
 fast_atan (MAE 1e-04): 0.787722 ns per atan (89.4% faster)  [per invokation: 3.303945 ms]
 fast_atan (MAE 1e-05): 0.863543 ns per atan (88.4% faster)  [per invokation: 3.621961 ms]
 fast_atan (MAE 1e-06): 0.951112 ns per atan (87.2% faster)  [per invokation: 3.989254 ms]

                  atan2: 13.759876 ns per atan2
 fast_atan2 (MAE 1e-02): 1.052900 ns per atan2 (92.3% faster)  [per invokation: 4.416183 ms]
 fast_atan2 (MAE 1e-03): 1.124720 ns per atan2 (91.8% faster)  [per invokation: 4.717417 ms]
 fast_atan2 (MAE 1e-04): 1.245389 ns per atan2 (90.9% faster)  [per invokation: 5.223540 ms]
 fast_atan2 (MAE 1e-05): 1.304229 ns per atan2 (90.5% faster)  [per invokation: 5.470334 ms]
 fast_atan2 (MAE 1e-06): 1.407788 ns per atan2 (89.8% faster)  [per invokation: 5.904690 ms]
Success!

On Linux/CUDA, it's slightly faster than the default LLVM implementation (there is no atan instruction in PTX):

                  atan: 0.012694 ns per atan
 fast_atan (MAE 1e-02): 0.008084 ns per atan (36.3% faster)  [per invokation: 0.542537 ms]
 fast_atan (MAE 1e-03): 0.008257 ns per atan (35.0% faster)  [per invokation: 0.554145 ms]
 fast_atan (MAE 1e-04): 0.008580 ns per atan (32.4% faster)  [per invokation: 0.575821 ms]
 fast_atan (MAE 1e-05): 0.009693 ns per atan (23.6% faster)  [per invokation: 0.650511 ms]
 fast_atan (MAE 1e-06): 0.009996 ns per atan (21.3% faster)  [per invokation: 0.670806 ms]

                  atan2: 0.016339 ns per atan2
 fast_atan2 (MAE 1e-02): 0.010460 ns per atan2 (36.0% faster)  [per invokation: 0.701942 ms]
 fast_atan2 (MAE 1e-03): 0.010887 ns per atan2 (33.4% faster)  [per invokation: 0.730619 ms]
 fast_atan2 (MAE 1e-04): 0.011134 ns per atan2 (31.9% faster)  [per invokation: 0.747207 ms]
 fast_atan2 (MAE 1e-05): 0.011699 ns per atan2 (28.4% faster)  [per invokation: 0.785120 ms]
 fast_atan2 (MAE 1e-06): 0.012122 ns per atan2 (25.8% faster)  [per invokation: 0.813505 ms]
Success!

On Linux/OpenCL, it is also slightly faster:

                  atan: 0.012427 ns per atan
 fast_atan (MAE 1e-02): 0.008740 ns per atan (29.7% faster)  [per invokation: 0.586513 ms]
 fast_atan (MAE 1e-03): 0.008920 ns per atan (28.2% faster)  [per invokation: 0.598603 ms]
 fast_atan (MAE 1e-04): 0.009326 ns per atan (25.0% faster)  [per invokation: 0.625840 ms]
 fast_atan (MAE 1e-05): 0.010362 ns per atan (16.6% faster)  [per invokation: 0.695404 ms]
 fast_atan (MAE 1e-06): 0.011196 ns per atan ( 9.9% faster)  [per invokation: 0.751366 ms]

                  atan2: 0.016028 ns per atan2
 fast_atan2 (MAE 1e-02): 0.011978 ns per atan2 (25.3% faster)  [per invokation: 0.803816 ms]
 fast_atan2 (MAE 1e-03): 0.011715 ns per atan2 (26.9% faster)  [per invokation: 0.786199 ms]
 fast_atan2 (MAE 1e-04): 0.011774 ns per atan2 (26.5% faster)  [per invokation: 0.790166 ms]
 fast_atan2 (MAE 1e-05): 0.012266 ns per atan2 (23.5% faster)  [per invokation: 0.823142 ms]
 fast_atan2 (MAE 1e-06): 0.012728 ns per atan2 (20.6% faster)  [per invokation: 0.854140 ms]
Success!

Precision tests:

Testing for precision 1.000000e-02...
    Testing fast_atan() correctness...  Passed: max abs error: 4.94057e-03
    Testing fast_atan2() correctness...  Passed: max abs error: 4.99773e-03

Testing for precision 1.000000e-03...
    Testing fast_atan() correctness...  Passed: max abs error: 6.07625e-04
    Testing fast_atan2() correctness...  Passed: max abs error: 6.13213e-04

Testing for precision 1.000000e-04...
    Testing fast_atan() correctness...  Passed: max abs error: 8.12709e-05
    Testing fast_atan2() correctness...  Passed: max abs error: 8.20160e-05

Testing for precision 1.000000e-05...
    Testing fast_atan() correctness...  Passed: max abs error: 1.69873e-06
    Testing fast_atan2() correctness...  Passed: max abs error: 1.90735e-06

Testing for precision 1.000000e-06...
    Testing fast_atan() correctness...  Passed: max abs error: 2.98023e-07
    Testing fast_atan2() correctness...  Passed: max abs error: 4.76837e-07
Success!

Optimizer

This PR includes a Python optimization script to find the coefficients of the polynomials:

[plot: atan_poly5_optimization]

While I didn't do anything very scientific or look at research papers, I have a hunch that the results from this script are really good (and may actually converge to the optimum).
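For reference, the core of such an iterative weight-adjusted least-squares loop is small. The sketch below follows the spirit of Lawson's algorithm (an assumption about the approach, not the PR's actual script): each sample's weight is multiplied by its current absolute error, so weight piles up on the error extrema and the fit is pushed toward the minimax (Chebyshev) solution.

```python
import numpy as np

# Lawson-style iteratively re-weighted least squares for atan on [-1, 1]
# with odd powers x .. x^11.
xs = np.linspace(-1.0, 1.0, 2001)
target = np.arctan(xs)
powers = np.arange(1, 12, 2)
A = xs[:, None] ** powers
w = np.ones_like(xs)
for _ in range(100):
    sw = np.sqrt(w)[:, None]
    # Weighted least squares: scale both sides by sqrt(weight).
    c, *_ = np.linalg.lstsq(A * sw, target * sw[:, 0], rcond=None)
    err = np.abs(A @ c - target)
    w = w * err
    w /= w.sum()                      # renormalize to avoid underflow
max_err = np.max(np.abs(A @ c - target))
```

After a modest number of iterations the maximum error approaches the minimax error for this term count, noticeably below what the unweighted fit achieves.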

If my optimization makes sense, then I have a funny observation: I get different coefficients for all of the fast approximations we currently have. See below.

Better coefficients for exp()?

My result:

// Coefficients with max error: 1.0835e-07
const float c_0(9.999998916957e-01f);
const float c_1(1.000010959810e+00f);
const float c_2(4.998191326645e-01f);
const float c_3(1.677545067148e-01f);
const float c_4(3.874100973369e-02f);
const float c_5(1.185256835401e-02f);

versus current Halide code:

https://github.com/halide/Halide/blob/3cdeb5398fb87be699fa830f843ca5d05fe6b983/src/IROperator.cpp#L1432-L1439

Better coefficients for sin()?

// Coefficients with max error: 1.3500e-11
const float c_1(9.999999998902e-01f);
const float c_3(-1.666666654172e-01f);
const float c_5(8.333329271330e-03f);
const float c_7(-1.984070354590e-04f);
const float c_9(2.751888510663e-06f);
const float c_11(-2.379517255457e-08f);

Notice that my optimization gives a maximal error of 1.35e-11, instead of the promised 1e-5, with 6 terms (degree 11).

Versus:

https://github.com/halide/Halide/blob/3cdeb5398fb87be699fa830f843ca5d05fe6b983/src/IROperator.cpp#L1390-L1394

If this is true (I don't see a reason why it wouldn't be), that would mean we can remove a few terms to get a faster version that still provides the promised precision.
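The sin claim is easy to spot-check numerically. Assuming the fit domain covers at least [-pi/4, pi/4] (a hedged assumption about the domain), a double-precision Horner evaluation of the coefficients quoted above against sin gives a maximum error orders of magnitude below the promised 1e-5:

```python
import numpy as np

# Coefficients quoted above, for odd powers x, x^3, ..., x^11.
c = [9.999999998902e-01, -1.666666654172e-01, 8.333329271330e-03,
     -1.984070354590e-04, 2.751888510663e-06, -2.379517255457e-08]
xs = np.linspace(-np.pi / 4, np.pi / 4, 10001)
# Horner evaluation in t = x^2, multiplied by x at the end.
t = xs * xs
p = c[-1]
for coef in reversed(c[:-1]):
    p = p * t + coef
approx = p * xs
max_err = np.max(np.abs(approx - np.sin(xs)))
```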

Better coefficients for cos()?

// Coefficients with max error: 2.2274e-10
const float c_0(9.999999997814e-01f);
const float c_2(-4.999999936010e-01f);
const float c_4(4.166663631608e-02f);
const float c_6(-1.388836211466e-03f);
const float c_8(2.476019687789e-05f);
const float c_10(-2.605210837614e-07f);

versus:

https://github.com/halide/Halide/blob/3cdeb5398fb87be699fa830f843ca5d05fe6b983/src/IROperator.cpp#L1396-L1400

Better coefficients for log()?

// Coefficients with max error: 2.2155e-08
const float c_0(2.215451521194e-08f);
const float c_1(9.999956758035e-01f);
const float c_2(-4.998600090003e-01f);
const float c_3(3.315834102478e-01f);
const float c_4(-2.389843462478e-01f);
const float c_5(1.605007787295e-01f);
const float c_6(-8.022296753549e-02f);
const float c_7(2.030898293785e-02f);

versus:

https://github.com/halide/Halide/blob/3cdeb5398fb87be699fa830f843ca5d05fe6b983/src/IROperator.cpp#L1357-L1365

mcourteaux commented 3 months ago

Apparently Windows/OpenCL on the build bot shows not a performance improvement but rather a performance degradation (about 15%):

C:\build_bot\worker\halide-testbranch-main-llvm20-x86-64-windows-cmake\halide-build\bin\performance_fast_arctan.exe
atan: 6.347030 ns per pixel
fast_atan: 7.295760 ns per pixel
atan2: 0.923191 ns per pixel
fast_atan2: 0.926148 ns per pixel
fast_atan more than 10% slower than atan on GPU.

Suggestions?

mcourteaux commented 3 months ago

The GPU performance test was severely memory-bandwidth limited. This has been worked around by computing many (1024) arctans per output and summing them. Now, at least on my system, they are faster. See the updated performance reports.

mcourteaux commented 3 months ago

Okay, this is ready for review. Vulkan is slow, but that is apparently well known...

mcourteaux commented 3 months ago

Oh dear... I don't even know what WebGPU is... @steven-johnson Is this supposed to be an actual platform that is fast, where performance metrics make sense? Or can I treat it like Vulkan, where it's just "meh, at least some are faster..."?

steven-johnson commented 3 months ago

Oh dear... I don't even know what WebGPU is... @steven-johnson Is this supposed to be an actual platform that is fast, and where performance metrics make sense? I can treat it like Vulkan, where it's just "meh, at least some are faster..."?

https://en.wikipedia.org/wiki/WebGPU https://www.w3.org/TR/webgpu/ https://github.com/gpuweb/gpuweb/wiki/Implementation-Status

derek-gerstmann commented 3 months ago

Okay, this is ready for review. Vulkan is slow, but that is apparently known well...

I don't think Vulkan is necessarily slow ... I think the benchmark loop is including initialization overhead. See my follow up here: https://github.com/halide/Halide/issues/7202

abadams commented 3 months ago

Very cool! I have some concerns with the error metric, though. Decimal digits of error isn't a great metric: e.g., having a value of 0.0001 when it's supposed to be zero is much, much worse than having a value of 0.3701 when it's supposed to be 0.37. Relative error isn't great either, due to the singularity at zero. A better metric is ULPs: the maximum number of distinct floating-point values between the answer and the correct answer.
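Concretely, the ULP distance between two float32 values can be computed from their bit patterns (a sketch):

```python
import struct

def ulp_distance(a, b):
    """Number of representable float32 values between a and b."""
    def key(f):
        # Reinterpret the float32 bit pattern as a signed 32-bit integer.
        (i,) = struct.unpack("<i", struct.pack("<f", f))
        # Mirror negative floats below zero so integer order matches float order.
        return i if i >= 0 else -0x80000000 - i
    return abs(key(a) - key(b))
```

With this ordering, +0.0 and -0.0 are 0 ULPs apart and adjacent floats are 1 ULP apart, so the metric has no singularity at zero.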

There are also cases where you want a hard constraint as opposed to a minimization. exp(0) should be exactly one, and I guess I decided its derivative should be exactly one too, which explains the difference in coefficients.
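Such hard constraints fit naturally into a least-squares setup: pin the constrained coefficients and fit only the remaining ones against the residual. A sketch for exp (the [0, ln 2] fit domain is an assumption for illustration, not necessarily what Halide uses):

```python
import numpy as np

# Enforce p(0) = 1 and p'(0) = 1 exactly by pinning c0 = c1 = 1, then
# least-squares fit only the free coefficients c2..c6 to the residual.
xs = np.linspace(0.0, np.log(2.0), 2001)
residual = np.exp(xs) - 1.0 - xs       # what the free terms must absorb
powers = np.arange(2, 7)               # x^2 .. x^6
A = xs[:, None] ** powers
c_free, *_ = np.linalg.lstsq(A, residual, rcond=None)

def exp_sketch(v):
    v = np.asarray(v, dtype=np.float64)
    return 1.0 + v + (v[..., None] ** powers) @ c_free

max_err = np.max(np.abs(exp_sketch(xs) - np.exp(xs)))
```

Because the free terms all start at x^2, the value and derivative at 0 are exact by construction, at a small cost in overall max error versus an unconstrained fit.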

mcourteaux commented 3 months ago

A better metric is ULPs, which is the maximum number of distinct floating point values in between the answer and the correct answer.

@abadams I improved the optimization script a lot. I added support for ULP optimization: it optimizes very nicely for maximal bit error.

[plot: atan_6_mulpe]

When instead optimizing for MAE, we see the max ULP distance increase:

[plot: atan_6_mae]

I changed the default to the ULP-optimized one, but to keep the maximal absolute error under 1e-5, I had to choose the higher-degree polynomial. Overall still good.

@derek-gerstmann Thanks a lot for investigating the performance issue! I now also get very fast Vulkan performance. I wonder why the overhead is so huge in Vulkan and not in the other backends.

Vulkan:

              atan: 0.009071 ns per atan
 fast_atan (Poly2): 0.005076 ns per atan (44.0% faster)  [per invokation: 0.340618 ms]
 fast_atan (Poly3): 0.005279 ns per atan (41.8% faster)  [per invokation: 0.354284 ms]
 fast_atan (Poly4): 0.005484 ns per atan (39.5% faster)  [per invokation: 0.368018 ms]
 fast_atan (Poly5): 0.005925 ns per atan (34.7% faster)  [per invokation: 0.397631 ms]
 fast_atan (Poly6): 0.006225 ns per atan (31.4% faster)  [per invokation: 0.417756 ms]
 fast_atan (Poly7): 0.006448 ns per atan (28.9% faster)  [per invokation: 0.432734 ms]
 fast_atan (Poly8): 0.006765 ns per atan (25.4% faster)  [per invokation: 0.453989 ms]

              atan2: 0.013717 ns per atan2
 fast_atan2 (Poly2): 0.007812 ns per atan2 (43.0% faster)  [per invokation: 0.524279 ms]
 fast_atan2 (Poly3): 0.007604 ns per atan2 (44.6% faster)  [per invokation: 0.510290 ms]
 fast_atan2 (Poly4): 0.008016 ns per atan2 (41.6% faster)  [per invokation: 0.537952 ms]
 fast_atan2 (Poly5): 0.008544 ns per atan2 (37.7% faster)  [per invokation: 0.573364 ms]
 fast_atan2 (Poly6): 0.008204 ns per atan2 (40.2% faster)  [per invokation: 0.550533 ms]
 fast_atan2 (Poly7): 0.008757 ns per atan2 (36.2% faster)  [per invokation: 0.587663 ms]
 fast_atan2 (Poly8): 0.008629 ns per atan2 (37.1% faster)  [per invokation: 0.579092 ms]
Success!

CUDA:

              atan: 0.010663 ns per atan
 fast_atan (Poly2): 0.006854 ns per atan (35.7% faster)  [per invokation: 0.459946 ms]
 fast_atan (Poly3): 0.006838 ns per atan (35.9% faster)  [per invokation: 0.458894 ms]
 fast_atan (Poly4): 0.007196 ns per atan (32.5% faster)  [per invokation: 0.482914 ms]
 fast_atan (Poly5): 0.007646 ns per atan (28.3% faster)  [per invokation: 0.513141 ms]
 fast_atan (Poly6): 0.008205 ns per atan (23.1% faster)  [per invokation: 0.550595 ms]
 fast_atan (Poly7): 0.008496 ns per atan (20.3% faster)  [per invokation: 0.570149 ms]
 fast_atan (Poly8): 0.009008 ns per atan (15.5% faster)  [per invokation: 0.604508 ms]

              atan2: 0.014594 ns per atan2
 fast_atan2 (Poly2): 0.009409 ns per atan2 (35.5% faster)  [per invokation: 0.631451 ms]
 fast_atan2 (Poly3): 0.009957 ns per atan2 (31.8% faster)  [per invokation: 0.668201 ms]
 fast_atan2 (Poly4): 0.010289 ns per atan2 (29.5% faster)  [per invokation: 0.690511 ms]
 fast_atan2 (Poly5): 0.010255 ns per atan2 (29.7% faster)  [per invokation: 0.688207 ms]
 fast_atan2 (Poly6): 0.010748 ns per atan2 (26.4% faster)  [per invokation: 0.721268 ms]
 fast_atan2 (Poly7): 0.011497 ns per atan2 (21.2% faster)  [per invokation: 0.771529 ms]
 fast_atan2 (Poly8): 0.011326 ns per atan2 (22.4% faster)  [per invokation: 0.760067 ms]
Success!

Vulkan is now even faster than CUDA! 🤯

mcourteaux commented 3 months ago

@steven-johnson The build just broke on something LLVM-related, it seems... There is no related Halide commit. Does LLVM get updated with every build?

Edit: I found the commit: https://github.com/llvm/llvm-project/commit/75c7bca740935a0cca462e28475dd6b046a6872c

Fix separately PR'd in #8391

steven-johnson commented 3 months ago

@steven-johnson The build just broke on something LLVM related it seems... There seems to be no related commit to Halide. Does LLVM constantly update with every build?

We rebuild LLVM once a day, about 2AM Pacific time.

mcourteaux commented 3 months ago

@abadams I added a check that counts the number of wrong mantissa bits:

Testing for precision 1.0e-02 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 4.96906e-03  max mantissa bits wrong: 19
    Testing fast_atan2() correctness...  Passed: max abs error: 4.96912e-03  max mantissa bits wrong: 19

Testing for precision 1.0e-03 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 6.10709e-04  max mantissa bits wrong: 17
    Testing fast_atan2() correctness...  Passed: max abs error: 6.10709e-04  max mantissa bits wrong: 17

Testing for precision 1.0e-04 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 8.16584e-05  max mantissa bits wrong: 14
    Testing fast_atan2() correctness...  Passed: max abs error: 8.17776e-05  max mantissa bits wrong: 14

Testing for precision 1.0e-05 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.78814e-06  max mantissa bits wrong: 9
    Testing fast_atan2() correctness...  Passed: max abs error: 1.90735e-06  max mantissa bits wrong: 9

Testing for precision 1.0e-06 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 3.57628e-07  max mantissa bits wrong: 6
    Testing fast_atan2() correctness...  Passed: max abs error: 4.76837e-07  max mantissa bits wrong: 7

Testing for precision 1.0e-02 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.31637e-03  max mantissa bits wrong: 15
    Testing fast_atan2() correctness...  Passed: max abs error: 1.31637e-03  max mantissa bits wrong: 15

Testing for precision 1.0e-03 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.54853e-04  max mantissa bits wrong: 12
    Testing fast_atan2() correctness...  Passed: max abs error: 1.54972e-04  max mantissa bits wrong: 12

Testing for precision 1.0e-04 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 2.53320e-05  max mantissa bits wrong: 9
    Testing fast_atan2() correctness...  Passed: max abs error: 2.55108e-05  max mantissa bits wrong: 9

Testing for precision 1.0e-05 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 3.63588e-06  max mantissa bits wrong: 6
    Testing fast_atan2() correctness...  Passed: max abs error: 3.81470e-06  max mantissa bits wrong: 6

Testing for precision 1.0e-06 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 5.96046e-07  max mantissa bits wrong: 4
    Testing fast_atan2() correctness...  Passed: max abs error: 7.15256e-07  max mantissa bits wrong: 4
Success!

Pay attention to the MULPE-optimized ones: their maximum mantissa-bit errors are significantly lower than those of the MAE-optimized ones.
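For reference, one plausible way to compute such a "mantissa bits wrong" figure (an assumption about the metric, not necessarily the exact formula used in the test) is to take log2 of the ULP distance between the two float32 values:

```python
import math
import struct

def mantissa_bits_wrong(approx, exact):
    # ULP distance via bit patterns: mirror negative floats so that
    # integer order matches float order, then take the difference.
    def key(f):
        (i,) = struct.unpack("<i", struct.pack("<f", f))
        return i if i >= 0 else -0x80000000 - i
    ulps = abs(key(approx) - key(exact))
    # A distance of up to 2^k ULPs means roughly the k low mantissa bits differ.
    return 0 if ulps == 0 else math.ceil(math.log2(ulps + 1))
```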

steven-johnson commented 3 months ago

Ping to @abadams or @zvookin for review

mcourteaux commented 1 month ago

Cut polynomial + merge it + later take care of other transcendentals.

mcourteaux commented 1 week ago

@abadams I updated the PR, and I believe this is a nice compromise of options. It is in line with your initial thoughts on just specifying the precision yourself. I have made a table of approximations and their precisions. A new auxiliary function then selects an approximation from that table that satisfies your requirements. This cleans up the header (no more one million enum options) and the source file, since the table no longer sits inside the fast_atan function.
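The selection logic described here can be sketched as a scan over a precision-sorted table. The entries below are made-up illustrative numbers (loosely echoing the measured errors earlier in the thread), not Halide's actual table:

```python
# Hypothetical table: (number of terms, max absolute error), sorted from
# cheapest to most precise. Values are illustrative only.
APPROX_TABLE = [
    (2, 5.0e-03),
    (3, 6.1e-04),
    (4, 8.2e-05),
    (5, 1.8e-06),
    (6, 3.0e-07),
]

def select_approximation(min_terms=0, max_abs_error=0.0):
    """Return the cheapest entry meeting both constraints (0 = no constraint)."""
    for terms, mae in APPROX_TABLE:
        if terms < min_terms:
            continue
        if max_abs_error > 0.0 and mae > max_abs_error:
            continue
        return terms, mae
    return APPROX_TABLE[-1]  # fall back to the most precise entry
```

With no constraints the cheapest approximation wins; tightening either constraint walks down the table.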

steven-johnson commented 2 days ago

Looks like this is ready for final review... ?