
[TIR][Schedule] Add annotate_buffer_access primitive #17423

Closed qsqqsqqsq-intellif closed 1 month ago

qsqqsqqsq-intellif commented 2 months ago

Overview

This PR introduces a new TIR schedule primitive annotate_buffer_access that allows explicit annotation of buffer access regions for both reads and writes.

Motivation

TVM currently cannot infer the numerical range of floating-point calculations. As a result, buffer access regions whose indices involve floating-point arithmetic cannot be inferred precisely and default to the full extent of the buffer. This new primitive addresses that limitation by allowing manual specification of access regions.
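To illustrate the gap (a minimal sketch using TVM's public arith API; this snippet is not part of the PR):

    # A rough illustration: TVM's arithmetic analyzer can bound integer
    # index expressions, but there is no analogous range analysis once an
    # index goes through floating-point ops, so region inference must fall
    # back to the full buffer extent.
    import tvm
    from tvm import arith, tir

    analyzer = arith.Analyzer()
    i = tir.Var("i", "int32")
    analyzer.bind(i, tvm.ir.Range(0, 16))   # i ranges over [0, 16)
    print(analyzer.const_int_bound(i * 2))  # integer index: bounded to [0, 30]
    # An index like T.Cast("int32", T.floor((i + 0.5) * 2.0 - 0.5)) cannot
    # be bounded this way, hence T.reads defaults to the whole buffer.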

Usage scenarios

This primitive is particularly useful for operations where the default buffer region inference may not capture the precise access patterns, such as in resize operations. It overrides the automatically inferred region for the specified buffer.
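As the examples below show, the primitive takes the target block, the index of the buffer within the block's read (or write) regions, the access type ("read" or "write"), and a gen_new_ranges function mapping the block's iteration variables to the desired region for each buffer dimension.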

Example

Trivial Example

before:

    @T.prim_func
    def before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, 0:32, 0:32])
                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))

Perform annotate_buffer_access:

    sch.annotate_buffer_access(block, 0, "read",
        gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
            v_i0,
            v_i1,
            (v_i2 * 2 - 3, v_i2 * 2 + 3),
            (v_i3 * 2 - 3, v_i3 * 2 + 3),
        ],
    )
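Each element returned by gen_new_ranges corresponds to one buffer dimension: a single expression denotes a point access, while a (start, end) pair denotes the half-open range start:end, as the rewritten T.reads below shows.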

after:

    @T.prim_func
    def after(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3])
                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"explicit_read_region": [0]})
                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))

The primitive adds an annotation (T.block_attr({"explicit_read_region": [0]})) to the block, indicating that an explicit region has been provided for the buffer at the given index. The CompactBufferAllocation pass uses this annotation to respect the manually specified region instead of relying on automatic inference.
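As a quick sanity check (a sketch that assumes the before function above and the new primitive), the annotation is visible on the block statement:

    from tvm import tir

    sch = tir.Schedule(before)          # "before" is the prim_func shown above
    block = sch.get_block("resize")
    sch.annotate_buffer_access(
        block, 0, "read",
        gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
            v_i0,
            v_i1,
            (v_i2 * 2 - 3, v_i2 * 2 + 3),
            (v_i3 * 2 - 3, v_i3 * 2 + 3),
        ],
    )
    print(sch.get(block).annotations)   # expect {"explicit_read_region": [0]}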

Resize Op Tile Example

We can optimize the tiling of the "cache" block for the "resize" operation using the annotate_buffer_access primitive.

before:

    @T.prim_func
    def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32")
        for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200):
            with T.block("cache"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
        for i0, i1, i2, i3 in T.grid(1, 3, 100, 100):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))]

Let's split the i2 and i3 loops of the "resize" block, then compute the "cache" block at the outer loop of "resize". This is a typical tiling schedule.

    # set up the schedule and look up the blocks from the prim_func above
    s = tir.Schedule(before)
    cache_block = s.get_block("cache")
    resize_block = s.get_block("resize")
    h, w = s.get_loops(resize_block)[-2:]
    ho, hi = s.split(h, factors=[10, 10])
    wo, wi = s.split(w, factors=[10, 10])
    s.reorder(ho, wo, hi, wi)
    s.compute_at(cache_block, wo)

After tiling without annotate_buffer_access:

    @T.prim_func
    def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer((1, 3, 200, 200))
        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
            for ax0, ax1 in T.grid(200, 200):
                with T.block("cache"):
                    v0 = T.axis.spatial(1, 0)
                    v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1])
                    T.reads(x[v0, v1, v2, v3])
                    T.writes(x_global[v0, v1, v2, v3])
                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
            for i2_1, i3_1 in T.grid(10, 10):
                with T.block("resize"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
                    T.reads(x_global[v_i0, v_i1, 0:200, 0:200])
                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]

Notice that after compute-at, the "cache" block still reads and copies the entire 200x200 input inside every one of the 10x10 tiles, because the inferred read region of "resize" covers the whole buffer. To optimize this, we can use annotate_buffer_access to explicitly annotate the read region of the "resize" block:

    s.annotate_buffer_access(
        resize_block,
        0,
        "read",
        lambda vn, vc, vh, vw: (vn, vc, (vh * 2 - 3, vh * 2 + 3), (vw * 2 - 3, vw * 2 + 3)),
    )
    s.compute_at(cache_block, wo)
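Note the ordering: compute_at derives the producer region required by the consumer from the consumer's read regions, so the annotation is applied before compute_at, letting it request only the tile's halo region instead of the whole input.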

After tiling with annotate_buffer_access:

    @T.prim_func
    def after_with_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer((1, 3, 200, 200))
        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
            for ax0, ax1 in T.grid(24, 24):
                with T.block("cache"):
                    v0 = T.axis.spatial(1, 0)
                    v1 = T.axis.spatial(3, i1)
                    v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0)
                    v3 = T.axis.spatial(200, i3_0 * 20 - 3 + ax1)
                    T.where(3 <= i2_0 * 20 + ax0 and i2_0 * 20 + ax0 < 203 and 3 <= i3_0 * 20 + ax1 and i3_0 * 20 + ax1 < 203)
                    T.reads(x[v0, v1, v2, v3])
                    T.writes(x_global[v0, v1, v2, v3])
                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
            for i2_1, i3_1 in T.grid(10, 10):
                with T.block("resize"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
                    T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6])
                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
                    T.block_attr({"explicit_read_region": [0]})
                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]

The "cache" block now only reads the necessary 24x24 region instead of the entire 200x200 input. These optimizations significantly reduce memory bandwidth requirements and improve cache efficiency, especially for larger input sizes.

Note

Caution should be exercised when using this primitive, as an incorrect annotation may lead to wrong code generation or runtime errors. It is crucial that the specified region covers all reads or writes the block actually performs on the given buffer; a region narrower than the block's true footprint would let CompactBufferAllocation shrink the buffer below what the block accesses.

cc @Hzfengsy @junrushao