tlc-pack / relax

Apache License 2.0

[Transform] RewriteDataflowReshape to op and VMBuiltinLower handling #415

Closed. MasterJH5574 closed this 1 year ago

MasterJH5574 commented 1 year ago

As discussed in https://github.com/tlc-pack/relax/pull/407#discussion_r1099586500, this PR updates the behavior of the RewriteDataflowReshape pass.

In short, prior to this PR, the pass transformed calls to reshape PrimFuncs in dataflow blocks into direct calls to the runtime packed func "vm.builtin.reshape". As a consequence, the memory planning pass had to detect reshape by string comparison on ExternFunc.global_symbol, which is not ideal.

Therefore, this PR changes RewriteDataflowReshape's behavior: it now transforms calls to reshape PrimFuncs into calls to our high-level reshape op "relax.reshape", and lets the VMBuiltinLower pass lower that op to calls to "vm.builtin.reshape".
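To illustrate the difference between the two detection strategies, here is a minimal sketch in plain Python. It does not use TVM: `Op`, `ExternFunc`, `Call`, `RESHAPE_OP`, `is_reshape_by_symbol`, `is_reshape_by_op`, and `lower_reshape` are hypothetical stand-ins invented for this example, and they only approximate how the real passes might be organized.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Op:
    """Toy stand-in for a registered high-level operator (assumption)."""
    name: str


@dataclass(frozen=True)
class ExternFunc:
    """Toy stand-in for a reference to a runtime packed function (assumption)."""
    global_symbol: str


@dataclass(frozen=True)
class Call:
    """Toy stand-in for a call node: a callee plus its arguments."""
    op: Union[Op, ExternFunc]
    args: Tuple[str, ...]


# One shared operator object, so passes can compare by identity.
RESHAPE_OP = Op("relax.reshape")


def is_reshape_by_symbol(call: Call) -> bool:
    # Old approach: the planner must know the runtime symbol string.
    return (
        isinstance(call.op, ExternFunc)
        and call.op.global_symbol == "vm.builtin.reshape"
    )


def is_reshape_by_op(call: Call) -> bool:
    # New approach: the planner checks the high-level operator itself.
    return call.op is RESHAPE_OP


def lower_reshape(call: Call) -> Call:
    # Later lowering step: the runtime symbol only appears here.
    if call.op is RESHAPE_OP:
        return Call(ExternFunc("vm.builtin.reshape"), call.args)
    return call


high_level = Call(RESHAPE_OP, ("x", "shape"))
lowered = lower_reshape(high_level)
```

In this toy model, a planner that runs before lowering only needs `is_reshape_by_op`, and the string "vm.builtin.reshape" is confined to the lowering step.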

masahi commented 1 year ago

Unrelated, but the following line that assumes the callee of call_tir is GlobalVar is incorrect, since it can also be ExternFunc after RunCodegen pass in the BYOC flow.

https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/rewrite_dataflow_reshape.cc#L71

Can you add a check call->args[0]->IsInstance<ExternFuncNode>() at L68?
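A rough sketch of the kind of guard I mean, written as a plain-Python toy model rather than the actual C++ pass. `GlobalVar`, `ExternFunc`, `lookup_prim_func`, and the dictionary-based module are hypothetical names used only for illustration:

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass(frozen=True)
class GlobalVar:
    """Toy stand-in for a reference to a function defined in the module."""
    name_hint: str


@dataclass(frozen=True)
class ExternFunc:
    """Toy stand-in for an external function, e.g. one produced by a codegen."""
    global_symbol: str


def lookup_prim_func(
    callee: Union[GlobalVar, ExternFunc], module: Dict[str, str]
) -> Optional[str]:
    # Guard first: an external callee has no definition in the module,
    # so there is nothing to inspect and the call should be left alone.
    if isinstance(callee, ExternFunc):
        return None
    # Only now is it safe to treat the callee as a GlobalVar.
    return module.get(callee.name_hint)


module = {"reshape": "prim_func_body"}
found = lookup_prim_func(GlobalVar("reshape"), module)
skipped = lookup_prim_func(ExternFunc("my_codegen_fn"), module)
```

The point is only the ordering: check for the external case before any code path that assumes the callee is a GlobalVar.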

MasterJH5574 commented 1 year ago

> Unrelated, but the following line that assumes the callee of call_tir is GlobalVar is incorrect, since it can also be ExternFunc after RunCodegen pass in the BYOC flow.
>
> https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/rewrite_dataflow_reshape.cc#L71
>
> Can you add a check call->args[0]->IsInstance<ExternFuncNode>() at L68?

@masahi Thanks! Sorry, that was an oversight on my part. I will add the check in the next commit.

MasterJH5574 commented 1 year ago

@masahi Updated :-)

slyubomirsky commented 1 year ago

We should probably document when in compilation this rewrite (and any other passes whose order is important) should happen, since there will be dependencies on it (like #407).

MasterJH5574 commented 1 year ago

> We should probably document when in compilation this rewrite (and any other passes whose order is important) should happen, since there will be dependencies on it (like #407).

@slyubomirsky Thanks! I added a note in both the C++ header and the Python API:

https://github.com/tlc-pack/relax/blob/3c79b61db652aef6d67527260015d1cca21a928b/include/tvm/relax/transform.h#L113-L115

https://github.com/tlc-pack/relax/blob/3c79b61db652aef6d67527260015d1cca21a928b/python/tvm/relax/transform/transform.py#L117-L122

MasterJH5574 commented 1 year ago

> Unrelated, but the following line that assumes the callee of call_tir is GlobalVar is incorrect, since it can also be ExternFunc after RunCodegen pass in the BYOC flow.
>
> https://github.com/tlc-pack/relax/blob/relax/src/relax/transform/rewrite_dataflow_reshape.cc#L71
>
> Can you add a check call->args[0]->IsInstance<ExternFuncNode>() at L68?

@masahi Hi Masa. Sorry that I didn't think more carefully about your comment earlier. One quick question: why do we use call_tir to call an ExternFunc instead of using the ExternFunc directly as CallNode::op? In my opinion, call_tir is supposed to call a TIR PrimFunc. If we want to call an ExternFunc, we can call it directly via Call(op=extern_func, args=my_args, ...).

masahi commented 1 year ago

I don't know, and I wondered too. cc @sunggg who wrote the code below.

https://github.com/tlc-pack/relax/blob/83adb870240a13ae75d0dcfac93e81bb2f3bcf59/src/relax/transform/run_codegen.cc#L84

MasterJH5574 commented 1 year ago

@masahi I just got a quick answer.

Here, call_tir on an ExternFunc plays the role of call_dps_packed. The ExternFunc's interface is in DPS (destination-passing style), so by calling it through call_tir we can write a normal call instead of a DPS call, and we don't need to allocate the result memory ahead of time.

We may need a dedicated op for call_dps_packed (#430), since using call_tir for both ordinary call_tir and call_dps_packed is quite confusing. Nevertheless, we can leave it as is for now. Sorry for the bother.
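As a rough illustration of the DPS (destination-passing style) idea, here is a small plain-Python sketch. It is not the TVM implementation: `add_one_dps` and `call_dps_like` are hypothetical names, and Python lists stand in for tensors.

```python
from typing import Callable, List


def add_one_dps(inp: List[float], out: List[float]) -> None:
    """DPS-style function: writes its result into a caller-provided buffer."""
    for i, value in enumerate(inp):
        out[i] = value + 1.0


def call_dps_like(
    func: Callable[[List[float], List[float]], None],
    inp: List[float],
    out_size: int,
) -> List[float]:
    """Wrapper that hides the output allocation from the caller."""
    out = [0.0] * out_size  # the wrapper allocates the destination buffer
    func(inp, out)          # the DPS function fills it in place
    return out              # the caller sees an ordinary value-returning call


# Raw DPS call: the caller allocates the destination first.
manual_out = [0.0, 0.0]
add_one_dps([1.0, 2.0], manual_out)

# Wrapped call: reads like a normal function call.
result = call_dps_like(add_one_dps, [1.0, 2.0], 2)
```

The wrapped form is what makes the call site look like a normal call, while the function being called keeps its DPS signature.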