halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Crash and Poor Performance Issues with 3D Stencil Computation using autoscheduler=Adams2019 #8400

Open lsh2233 opened 2 months ago

lsh2233 commented 2 months ago

I'm experiencing issues while testing the performance of a 3D stencil computation in Halide using the Adams2019 autoscheduler. The computation is defined as:

output(x, y, z) = k0 * input(x - 1, y, z) + k1 * input(x, y - 1, z) + k2 * input(x, y, z - 1) + k3 * input(x + 1, y, z) + k4 * input(x, y + 1, z) + k5 * input(x, y, z + 1) + k6 * input(x, y, z)

The computation works perfectly when using the Mullapudi2016 autoscheduler. However, with Adams2019, I encounter two major issues:

  1. Crash with Specific Input Shape: When the input has the shape [256, 256, 256], the program crashes.
  2. Poor Performance with Other Input Shapes: For other input shapes, the computation runs far slower than expected. For example, with shape [512, 512, 512] the manually tuned schedule takes 9.90154 s, while the auto-scheduled version takes 173.521 s (about 17.5 times slower).
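When comparing numbers like 9.90154 s and 173.521 s, it helps to be sure both schedules are timed the same way. Below is a minimal, hypothetical helper (the name `best_of_seconds` is illustrative, not a Halide API) that reports the fastest of several runs using `std::chrono::steady_clock`; Halide's own `tools/halide_benchmark.h` offers a more thorough alternative.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <functional>

// Hypothetical helper: runs `op` `samples` times and returns the fastest
// wall-clock duration in seconds. Taking the minimum filters out one-off
// noise such as cold caches or page faults on the first run.
double best_of_seconds(int samples, const std::function<void()> &op) {
    assert(samples > 0);
    double best = 0.0;
    for (int i = 0; i < samples; ++i) {
        const auto start = std::chrono::steady_clock::now();
        op();
        const auto stop = std::chrono::steady_clock::now();
        const double elapsed = std::chrono::duration<double>(stop - start).count();
        best = (i == 0) ? elapsed : std::min(best, elapsed);
    }
    return best;
}
```

Usage would look like `best_of_seconds(5, [&] { stencil_chain(in_buf, out_buf); });`, where `stencil_chain`, `in_buf` and `out_buf` are placeholders for however the generated pipeline is actually invoked.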

Below is my Halide generator code:

#include "Halide.h"

#define k0 0.1f
#define k1 0.2f
#define k2 0.1f
#define k3 0.2f
#define k4 0.2f
#define k5 0.2f
#define k6 0.1f

const int zdim = 1024;
const int ydim = 1024;
const int xdim = 1024;

namespace
{

    class StencilChain : public Halide::Generator<StencilChain>
    {
    public:
        Input<Buffer<float, 3>> input{"input"};
        Output<Buffer<float, 3>> output{"output"};

        void generate()
        {
            Var x("x"), y("y"), z("z");

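            // Boundary condition: clamp each coordinate to [1, extent - 1] so the
            // +/-1 neighbor accesses below never read outside the input.
            Func clamped("clamped");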
            Expr clamped_x = clamp(x, 1, input.width() - 1);
            Expr clamped_y = clamp(y, 1, input.height() - 1);
            Expr clamped_z = clamp(z, 1, input.channels() - 1);
            clamped(x, y, z) = input(clamped_x, clamped_y, clamped_z);

            // The algorithm
            output(x, y, z) = k0 * clamped(x - 1, y, z) +
                              k1 * clamped(x, y - 1, z) +
                              k2 * clamped(x, y, z - 1) +
                              k3 * clamped(x + 1, y, z) +
                              k4 * clamped(x, y + 1, z) +
                              k5 * clamped(x, y, z + 1) +
                              k6 * clamped(x, y, z);

            Var yi("yi");  // inner y variable for the manual schedule's split

            // Size estimates for the autoscheduler: min 0, extent 1024 in every dimension.
            input.set_estimates({{0, xdim}, {0, ydim}, {0, zdim}});
            output.set_estimates({{0, xdim}, {0, ydim}, {0, zdim}});
            if (using_autoscheduler())
            {
                // The schedule is supplied by the autoscheduler chosen at generation time.
            }
            else
            {
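                // Manual schedule: split y by 32, run z slices in parallel, and
                // vectorize x at the target's native float vector width.
                const int vec = natural_vector_size<float>();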
                output
                    .split(y, y, yi, 32)
                    .parallel(z)
                    .vectorize(x, vec);
            }
        }
    };

} // namespace

HALIDE_REGISTER_GENERATOR(StencilChain, stencil_chain)
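To check that the autoscheduled build at least computes the same values as the manual one, both outputs can be compared against a plain C++ reference. The sketch below is hypothetical: `Volume`, `clamp_coord`, `reference_stencil` and `max_abs_diff` are illustrative names, not Halide APIs, and it assumes a dense layout with x as the innermost (stride-1) dimension and every extent at least 2. It mirrors the generator's boundary handling by clamping each coordinate to [1, extent - 1], so index 0 along any axis is never read.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical dense volume: x is innermost (stride 1), then y, then z.
struct Volume {
    int nx, ny, nz;
    std::vector<float> data;

    Volume(int nx_, int ny_, int nz_)
        : nx(nx_), ny(ny_), nz(nz_),
          data(static_cast<std::size_t>(nx_) * ny_ * nz_, 0.0f) {}

    std::size_t index(int x, int y, int z) const {
        return (static_cast<std::size_t>(z) * ny + y) * nx + x;
    }
    float &at(int x, int y, int z) { return data[index(x, y, z)]; }
    float at(int x, int y, int z) const { return data[index(x, y, z)]; }
};

// Same bounds as the generator's clamp(v, 1, extent - 1).
inline int clamp_coord(int v, int extent) {
    return std::max(std::min(v, extent - 1), 1);
}

// Reference 7-point stencil using the same weights as k0..k6 in the generator.
Volume reference_stencil(const Volume &in) {
    assert(in.nx >= 2 && in.ny >= 2 && in.nz >= 2);
    const float k[7] = {0.1f, 0.2f, 0.1f, 0.2f, 0.2f, 0.2f, 0.1f};
    auto c = [&in](int x, int y, int z) {
        return in.at(clamp_coord(x, in.nx), clamp_coord(y, in.ny), clamp_coord(z, in.nz));
    };
    Volume out(in.nx, in.ny, in.nz);
    for (int z = 0; z < in.nz; ++z)
        for (int y = 0; y < in.ny; ++y)
            for (int x = 0; x < in.nx; ++x)
                out.at(x, y, z) = k[0] * c(x - 1, y, z) + k[1] * c(x, y - 1, z) +
                                  k[2] * c(x, y, z - 1) + k[3] * c(x + 1, y, z) +
                                  k[4] * c(x, y + 1, z) + k[5] * c(x, y, z + 1) +
                                  k[6] * c(x, y, z);
    return out;
}

// Largest element-wise |a - b|. Compare against a small tolerance rather than
// zero, since compiled pipelines may not round the float sum identically.
float max_abs_diff(const Volume &a, const Volume &b) {
    assert(a.data.size() == b.data.size());
    float worst = 0.0f;
    for (std::size_t i = 0; i < a.data.size(); ++i)
        worst = std::max(worst, std::fabs(a.data[i] - b.data[i]));
    return worst;
}
```

For example, fill a `Volume` with the same data passed to `input`, copy each pipeline's `output` into its own `Volume`, and check that `max_abs_diff` against the reference stays near zero for both the manual and the Adams2019 builds.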

Environment: