cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

[Primitive] Extensible primitives and transformations in Python #504

Closed. chhzh123 closed this pull request 1 year ago.

chhzh123 commented 1 year ago

This PR refactors the schedule primitives by decoupling them from the Schedule and Stage classes and keeping each one in a separate file. This makes the code structure clearer and lets users plug in their own customizations as primitives.

The following code shows an example primitive that creates a buffer for each function argument. Users can inherit from our base Primitive class and define their own primitive by implementing the apply function. We provide several helper functions in heterocl.ir.transform for program transformations, so users can create intermediate buffers or implement more complicated optimizations in a few lines of code. After implementing the transformation, users only need to decorate the class with register_primitive, and HeteroCL registers it automatically at runtime.

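import heterocl as hcl
import heterocl.ir.transform as hir

# Register the primitive so that it can be invoked as s.buffer_root()
@hcl.register_primitive()
class BufferRootPrimitive(hcl.Primitive):
    name = "buffer_root"  # name of the schedule method exposed to users

    @staticmethod
    def apply(sch):
        # Loop nests of the top-level function; take the first one
        loops = hir.get_affine_loop_nests(sch.top_func)[0]
        # Create a buffer arg_<i> for every argument, using the outermost
        # loop of that nest (loops[0][1]) as the insertion point
        for i, arg in enumerate(sch.top_func.arguments):
            hir.create_buffer(arg, f"arg_{i}", ip=loops[0][1])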
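To make the registration step concrete, here is a minimal, self-contained sketch of how a decorator-based registry could expose a primitive as a schedule method. Schedule, Primitive, and register_primitive below are hypothetical toy stand-ins written for illustration; they are not HeteroCL's actual classes, and the toy apply only returns buffer names instead of rewriting any IR.

```python
# Hypothetical sketch: NOT HeteroCL's implementation. It only illustrates how a
# decorator-based registry can expose Primitive subclasses as schedule methods.
from typing import List


class Schedule:
    """Toy stand-in for a schedule; it just records the applied primitives."""

    def __init__(self, top_func: str, arguments: List[str]):
        self.top_func = top_func
        self.arguments = arguments
        self.applied: List[str] = []  # primitive names, in application order


class Primitive:
    """Base class: subclasses set `name` and implement a static `apply`."""

    name: str = ""

    @staticmethod
    def apply(sch, *args, **kwargs):
        raise NotImplementedError


def register_primitive():
    """Return a class decorator that attaches `cls.apply` to Schedule as `cls.name`."""

    def decorator(cls):
        if not (isinstance(cls, type) and issubclass(cls, Primitive)):
            raise TypeError(f"{cls!r} is not a Primitive subclass")
        if not cls.name:
            raise ValueError(f"{cls.__name__} must define a non-empty `name`")
        if hasattr(Schedule, cls.name):
            raise ValueError(f"a schedule method named {cls.name!r} already exists")

        def method(self, *args, **kwargs):
            result = cls.apply(self, *args, **kwargs)
            self.applied.append(cls.name)
            return result

        method.__name__ = cls.name
        method.__doc__ = cls.__doc__
        setattr(Schedule, cls.name, method)
        return cls

    return decorator


@register_primitive()
class BufferRootPrimitive(Primitive):
    """Toy buffer_root: names one buffer per argument instead of editing IR."""

    name = "buffer_root"

    @staticmethod
    def apply(sch):
        return [f"arg_{i}" for i, _ in enumerate(sch.arguments)]


s = Schedule("gemm", ["A", "B"])
print(s.buffer_root())  # ['arg_0', 'arg_1']
print(s.applied)        # ['buffer_root']
```

Attaching a real method to the class keeps s.buffer_root() visible to dir() and editor completion; looking names up in a registry dict from Schedule.__getattr__ would avoid mutating the class, at the cost of that discoverability. The duplicate-name check is one way to stop a plugin from silently shadowing a builtin primitive.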

In the main function, users apply their newly defined primitives just as they would the builtin ones. In this example, calling s.buffer_root() invokes the primitive. This makes primitives more modular, extensible, and reusable, and opens up more optimization opportunities in the future.


def test_gemm_buffer(M=32, N=32, K=32, dtype=hcl.Float(), target="vhls"):
    hcl.init(dtype)  # default data type for placeholders and computes
    A = hcl.placeholder((M, K), name="A")
    B = hcl.placeholder((K, N), name="B")

    def gemm(A, B):
        k = hcl.reduce_axis(0, K, name="k")
        C = hcl.compute((M, N), lambda i, j: hcl.sum(A[i, k] * B[k, j], axis=k), "C")
        return C

    s = hcl.create_schedule([A, B], gemm)

    # optimization
    C = gemm.C
    s[C].reorder(C.axis[1], C.axis[0])  # swap the i and j loops
    s.buffer_root()  # user-defined primitive registered above
    hcl.build(s, target=target)
    print(s.module)
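As a rough mental model of the test above, the following plain-Python sketch computes the same GEMM with the loop order that reorder(C.axis[1], C.axis[0]) produces (j outer, i inner), using local copies to stand in for the buffers buffer_root creates. It assumes buffer_root copies each input argument into a full-size local buffer before the loop nest; the function names are hypothetical and the code is not what HeteroCL generates.

```python
# Conceptual sketch only: plain Python mirroring the *meaning* of the GEMM test
# after reorder(j, i) and buffer_root, assuming buffer_root copies every input
# argument into a local buffer before the loop nest. It is not HeteroCL output.
from typing import List

Matrix = List[List[float]]


def gemm_reference(A: Matrix, B: Matrix) -> Matrix:
    """Straightforward C[i][j] = sum_k A[i][k] * B[k][j] with loop order i, j."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]


def gemm_reordered_buffered(A: Matrix, B: Matrix) -> Matrix:
    M, K, N = len(A), len(B), len(B[0])
    # buffer_root analogue: one local copy per argument (arg_0 <- A, arg_1 <- B)
    arg_0 = [row[:] for row in A]
    arg_1 = [row[:] for row in B]
    C = [[0.0] * N for _ in range(M)]
    # reorder(C.axis[1], C.axis[0]) analogue: j becomes the outer loop
    for j in range(N):
        for i in range(M):
            acc = 0.0
            for k in range(K):  # reduction axis k
                acc += arg_0[i][k] * arg_1[k][j]
            C[i][j] = acc
    return C


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
print(gemm_reordered_buffered(A, B))  # [[19.0, 22.0], [43.0, 50.0]]
```

Because the reduction over k runs in the same order in both functions, gemm_reference and gemm_reordered_buffered agree element for element: loop reordering and buffering change the schedule, not the result.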

chhzh123 commented 1 year ago

Thanks @zzzDavid!