Closed Hzfengsy closed 1 year ago
`vm.store_shape` and `vm.load_shape` are emitted by the VMShapeLower pass to lower shape expressions in Relax to calculations on the shape heap. See the test at https://github.com/tlc-pack/relax/blob/2bb8a14d0e0c8e4fb55771bfdd0166c117d22fc0/tests/python/relax/test_transform.py#L359-L393
For example, if you run the pass on the following module:

```python
@tvm.script.ir_module
class TestVMShapeLower:
    @R.function
    def foo(x: R.Tensor(dtype="float32")):
        m, n = T.var("int64"), T.var("int64")
        R.match_shape(x, (n, m))
        return (n * 2, m * 3)
```
After running VMShapeLower, the module becomes:
```python
@tvm.script.ir_module
class Module:
    @R.function
    def foo(x: Tensor(_, "float32", ndim = -1)) -> Shape:
        # block 0
        shape_heap: Tensor((4,), "int64") = R.call_packed("vm.builtin.alloc_shape_heap", (4,))
        # block 1
        sh: Object = R.call_packed("vm.builtin.shape_of", x)
        gv: Tuple() = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key="relax.attrs.ShapeHeapAttrs")
        # block 2
        _ = shape_func(shape_heap)
        sh1: Shape = relax.vm.builtin.load_shape(shape_heap, indices=[2, 3], attrs_type_key="relax.attrs.ShapeHeapAttrs")
        return sh1

    @T.prim_func
    def shape_func(H: T.Buffer[T.int64(4), "int64"]):
        # function attr dict
        T.func_attr({"global_symbol": "shape_func"})
        # body
        H[2] = H[0] * T.int64(2)
        H[3] = H[1] * T.int64(3)
```
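To see the runtime behavior this lowering targets, here is a minimal Python sketch of the shape-heap mechanism. The helper names below are illustrative stand-ins, not the actual TVM runtime API: `store_shape` writes a tensor's dimensions into an int64 heap at the given slots, the generated `shape_func` computes the output entries, and `load_shape` reads them back as a shape tuple.

```python
# Minimal sketch of the shape-heap mechanism behind VMShapeLower.
# store_shape / load_shape / shape_func are hypothetical stand-ins,
# not the real TVM runtime functions.

def store_shape(shape, heap, indices):
    """Write each dimension of `shape` into `heap` at the given slots."""
    for dim, idx in zip(shape, indices):
        heap[idx] = dim

def load_shape(heap, indices):
    """Read the given heap slots back as a shape tuple."""
    return tuple(heap[i] for i in indices)

def shape_func(heap):
    """Mirrors the generated TIR: H[2] = H[0] * 2; H[3] = H[1] * 3."""
    heap[2] = heap[0] * 2
    heap[3] = heap[1] * 3

heap = [0] * 4                     # vm.builtin.alloc_shape_heap((4,))
store_shape((5, 7), heap, [0, 1])  # match_shape(x, (n, m)) with n=5, m=7
shape_func(heap)
print(load_shape(heap, [2, 3]))    # -> (10, 21), i.e. (n * 2, m * 3)
```

With a tensor of shape (5, 7), the heap ends up holding [5, 7, 10, 21], and the loaded output shape is (10, 21), matching the `(n * 2, m * 3)` in the source module.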
Thanks @psrivas2 for the information. However, that code was removed in https://github.com/tlc-pack/relax/pull/324. I'm not sure whether the two ops `vm.store_shape` and `vm.load_shape` are still needed after the match_cast refactor.

cc @tqchen
`vm.store_shape` and `vm.load_shape` are no longer needed after the match-cast refactor; they are replaced by the builtins `match_shape` and `make_shape`. We can remove the old intrinsics.
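For intuition about the replacements, here is a hedged sketch (hypothetical helper names, not the real builtin signatures) of what the newer builtins do: `match_shape` checks a tensor's runtime rank against a pattern and binds each dimension to a symbolic-variable slot, while `make_shape` constructs a shape value from computed slots.

```python
# Illustrative sketch of the builtins that replaced vm.store_shape /
# vm.load_shape after the match-cast refactor. These are simplified
# stand-ins, not the actual relax.vm builtin signatures.

def match_shape(tensor_shape, heap, indices):
    """Check rank, then bind each dimension to a heap slot (symbolic var)."""
    if len(tensor_shape) != len(indices):
        raise ValueError("rank mismatch in match_shape")
    for dim, idx in zip(tensor_shape, indices):
        heap[idx] = dim

def make_shape(heap, indices):
    """Construct a shape value from the bound heap slots."""
    return tuple(heap[i] for i in indices)

heap = [0] * 4
match_shape((5, 7), heap, [0, 1])             # binds n=5, m=7
heap[2], heap[3] = heap[0] * 2, heap[1] * 3   # the shape computation
print(make_shape(heap, [2, 3]))               # -> (10, 21)
```

Compared with the old intrinsics, the key difference is that `match_shape` also performs the rank/shape check that `R.match_shape` (now match_cast) implies, rather than only storing dimensions.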
This PR contains the following changes:

- Move `relax.op.vm.builtin` to `relax.op.vm`, e.g. `relax.op.vm.builtin.alloc_tensor` -> `relax.op.vm.alloc_tensor`.
- Remove `vm.store_shape` and `vm.load_shape`.

cc @vinx13 @YuchenJin