tlc-pack / relax

[VM][TVMScript] Shorten VM op names and add TVMScript support #367

Closed Hzfengsy closed 1 year ago

Hzfengsy commented 1 year ago

This PR contains the following changes:

  1. Shorten the VM op prefix from relax.op.vm.builtin to relax.op.vm, e.g. relax.op.vm.builtin.alloc_tensor -> relax.op.vm.alloc_tensor.
  2. Add TVMScript support for VM ops.
  3. Remove the unused vm.store_shape and vm.load_shape ops.

cc @vinx13 @YuchenJin

psrivas2 commented 1 year ago

vm.store_shape and vm.load_shape are emitted by the VMShapeLower pass, which lowers shape expressions in Relax to calculations on the shape heap. See the test at https://github.com/tlc-pack/relax/blob/2bb8a14d0e0c8e4fb55771bfdd0166c117d22fc0/tests/python/relax/test_transform.py#L359-L393

For example, if you run the pass on the following module:

@tvm.script.ir_module
class TestVMShapeLower:
    @R.function
    def foo(x: R.Tensor(dtype="float32")):
        m, n = T.var("int64"), T.var("int64")
        R.match_shape(x, (n, m))
        return (n * 2, m * 3)

After VMShapeLower, the module becomes:

@tvm.script.ir_module
class Module:
    @R.function
    def foo(x: Tensor(_, "float32", ndim = -1)) -> Shape:
        # block 0
        shape_heap: Tensor((4,), "int64") = R.call_packed("vm.builtin.alloc_shape_heap", (4,))
        # block 1
        sh: Object = R.call_packed("vm.builtin.shape_of", x)
        gv: Tuple() = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key="relax.attrs.ShapeHeapAttrs")
        # block 2
        _ = shape_func(shape_heap)
        sh1: Shape = relax.vm.builtin.load_shape(shape_heap, indices=[2, 3], attrs_type_key="relax.attrs.ShapeHeapAttrs")
        return sh1

    @T.prim_func
    def shape_func(H: T.Buffer[T.int64(4), "int64"]):
        # function attr dict
        T.func_attr({"global_symbol": "shape_func"})
        # body
        H[2] = H[0] * T.int64(2)
        H[3] = H[1] * T.int64(3)
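
The data flow in the lowered module can be sketched in plain Python. This is only an illustrative model, not TVM or Relax code: the helper functions below are hypothetical stand-ins named after the ops in the printed IR, the shape heap is modeled as a Python list, and the input shape (5, 7) is made up for the example.

```python
from typing import List, Sequence, Tuple


def alloc_shape_heap(size: int) -> List[int]:
    # Hypothetical stand-in for vm.builtin.alloc_shape_heap: a flat int64 buffer.
    return [0] * size


def store_shape(shape: Sequence[int], heap: List[int], indices: Sequence[int]) -> None:
    # Stand-in for store_shape: write each dimension of `shape` into the given heap slot.
    for dim, idx in zip(shape, indices):
        heap[idx] = dim


def load_shape(heap: List[int], indices: Sequence[int]) -> Tuple[int, ...]:
    # Stand-in for load_shape: read the listed heap slots back out as a shape.
    return tuple(heap[idx] for idx in indices)


def shape_func(heap: List[int]) -> None:
    # Mirrors the generated prim_func: H[2] = H[0] * 2, H[3] = H[1] * 3.
    heap[2] = heap[0] * 2
    heap[3] = heap[1] * 3


def lowered_foo(x_shape: Sequence[int]) -> Tuple[int, ...]:
    heap = alloc_shape_heap(4)           # block 0
    store_shape(x_shape, heap, [0, 1])   # block 1: n -> slot 0, m -> slot 1
    shape_func(heap)                     # block 2: compute n * 2 and m * 3
    return load_shape(heap, [2, 3])


print(lowered_foo((5, 7)))  # (10, 21)
```

Slots 0 and 1 hold the matched symbolic dimensions n and m, and slots 2 and 3 hold the computed results, which is why the heap in the example has four entries.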

Hzfengsy commented 1 year ago

Thanks @psrivas2 for the information. However, that code was removed in https://github.com/tlc-pack/relax/pull/324.

I'm not sure whether the two ops vm.store_shape and vm.load_shape are still needed after the match_cast refactor.

cc @tqchen

tqchen commented 1 year ago

vm.store_shape and vm.load_shape are no longer needed after the match_cast refactor. They are replaced by the builtin match_shape and make_shape, so we can remove the old intrinsics.