tlc-pack / relax

Apache License 2.0

[Layout] Add layout transformation analysis for PrimFunc #449

Closed: psrivas2 closed this 1 year ago

psrivas2 commented 1 year ago

This change adds a PrimFunc-level analysis that proposes layout transformations for the blocks and buffers in a PrimFunc, based on the layout transformations applied to the PrimFunc's outputs. It derives these proposals by analyzing buffer accesses, and while doing so it tries to preserve sequential access to buffers.
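The propagation idea can be illustrated with a heavily simplified, pure-Python sketch. This is not the PR's implementation and does not use TVM APIs; `propose_block_transform` and `propose_read_buffer_transform` are hypothetical names. It assumes every buffer access is a permutation of the block's iteration variables (as in an elementwise block), models an access as a tuple of variable names, and models a layout transformation as a Python function on indices. It does not attempt the sequential-access heuristics described above.

```python
from typing import Callable, Optional, Sequence, Tuple

Index = Tuple[int, ...]
IndexMap = Callable[..., Index]


def propose_block_transform(
    block_vars: Sequence[str], write_access: Sequence[str], buffer_map: IndexMap
) -> Optional[IndexMap]:
    """Derive a block iteration-space transformation from the write buffer's layout map.

    Hypothetical helper. Only handles write accesses whose indices are a permutation of
    the block iter vars, e.g. relu[v0, v1, v2, v3]; returns None otherwise.
    """
    if sorted(write_access) != sorted(block_vars):
        return None

    def block_map(*values: int) -> Index:
        env = dict(zip(block_vars, values))
        # Feed the iter-var values to the buffer map in the order the buffer is indexed.
        return buffer_map(*(env[v] for v in write_access))

    return block_map


def propose_read_buffer_transform(
    block_vars: Sequence[str], read_access: Sequence[str], block_map: IndexMap
) -> Optional[IndexMap]:
    """Propagate a block transformation to a read buffer accessed by a permutation of iter vars.

    Hypothetical helper: the read buffer gets a layout in which the transformed block
    visits it in the same order as the transformed write buffer.
    """
    if sorted(read_access) != sorted(block_vars):
        return None

    def read_map(*indices: int) -> Index:
        env = dict(zip(read_access, indices))
        # Reorder the read indices into block iter-var order, then apply the block map.
        return block_map(*(env[v] for v in block_vars))

    return read_map
```

For example, with block vars `("v0", "v1", "v2", "v3")`, a write access `relu[v0, v1, v2, v3]`, and the map `lambda n, c, h, w: (n, h, w, c // 4, c % 4)`, the block point `(1, 6, 2, 3)` maps to `(1, 2, 3, 1, 2)`; a transposed read such as `arg[v0, v2, v3, v1]` is then assigned a layout that lands on the same transformed index.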

For example, given the following PrimFunc and the transformation lambda n, c, h, w: (n, h, w, c // 4, c % 4) on the write buffer "relu", it will suggest the following transformations.

```python
from tvm.script import tir as T


@T.prim_func
def elemwise_relu(
    arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
    relu: T.Buffer((32, 224, 224, 16, 4), "float32"),
):
    # Elementwise ReLU over the 5-D (N, H, W, C // 4, C % 4) layout.
    for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4):
        with T.block("compute"):
            v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
            T.reads(arg[v0, v1, v2, v3, v4])
            T.writes(relu[v0, v1, v2, v3, v4])
            relu[v0, v1, v2, v3, v4] = T.max(arg[v0, v1, v2, v3, v4], T.float32(0))
```
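The index map used in this example can be checked in plain Python. This is a sketch, not TVM's `IndexMap` API: `relu_layout`, `relu_layout_inverse`, `transformed_shape`, and `is_bijective_on` are hypothetical helpers. It assumes the original "relu" buffer was NCHW with 64 channels, i.e. `(32, 64, 224, 224)`, which is consistent with the `16 * 4` trailing extents above.

```python
from itertools import product
from typing import Tuple

Index = Tuple[int, ...]


def relu_layout(n: int, c: int, h: int, w: int) -> Index:
    """The write-buffer transformation from the example: NCHW -> (N, H, W, C // 4, C % 4)."""
    return (n, h, w, c // 4, c % 4)


def relu_layout_inverse(n: int, h: int, w: int, c_outer: int, c_inner: int) -> Index:
    """Hypothetical inverse mapping: recover the NCHW index."""
    return (n, c_outer * 4 + c_inner, h, w)


def transformed_shape(shape: Index, factor: int = 4) -> Index:
    """Map NCHW extents to the 5-D extents; assumes C is divisible by factor."""
    n, c, h, w = shape
    if c % factor != 0:
        raise ValueError(f"channel extent {c} is not divisible by {factor}")
    return (n, h, w, c // factor, factor)


def is_bijective_on(shape: Index) -> bool:
    """Brute-force check on a small shape that the mapping is in bounds and one-to-one."""
    new_shape = transformed_shape(shape)
    seen = set()
    for idx in product(*(range(e) for e in shape)):
        new_idx = relu_layout(*idx)
        if any(not (0 <= i < e) for i, e in zip(new_idx, new_shape)):
            return False
        if relu_layout_inverse(*new_idx) != idx:
            return False
        seen.add(new_idx)
    total = 1
    for e in new_shape:
        total *= e
    return len(seen) == total
```

Under that assumption, `transformed_shape((32, 64, 224, 224))` gives `(32, 224, 224, 16, 4)`, matching the buffer shapes in the PrimFunc, and channel `c = 6` lands at inner position `(1, 2)`.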
tqchen commented 1 year ago

Per https://github.com/tlc-pack/relax/issues/453, this PR should be one of the PRs sent to unity. We can either transition it directly to a PR against unity, or continue the review and merge here before transitioning, depending on what the authors and reviewers prefer.

psrivas2 commented 1 year ago

cc @sunggg @masahi @YuchenJin

psrivas2 commented 1 year ago

Migrated to the unity branch: https://github.com/apache/tvm/pull/14066