tlc-pack / relax

Apache License 2.0

[Fix][TVMScript] Parse and print `tir_vars` of `call_tir` properly #361

Closed: MasterJH5574 closed this 1 year ago.

MasterJH5574 commented 1 year ago

This PR fixes how both the TVMScript parser and the printer handle the `tir_vars` parameter of `call_tir`.

When a TIR function contains symbolic shapes, its signature can have trailing TIR vars in its parameter list. In such cases, when using `call_tir` to call into the TIR function, we should provide the values of those TIR vars through the `tir_vars` parameter. `tir_vars` is the last parameter of `call_tir`, immediately after `dtype`: https://github.com/tlc-pack/relax/blob/697675bf8f137488e0f38237b046ce671cdbad9f/python/tvm/relax/op/base.py#L45-L51
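To make the binding concrete, here is a minimal pure-Python sketch of how the trailing TIR vars line up with `tir_vars`. It is not the TVM API: `call_tir_sketch` and `add_one` are hypothetical names, the "TIR function" is a plain Python function whose last parameter plays the role of a symbolic var `n`, and `dtype` is omitted for brevity. The sketch only mimics the calling order (inputs, then the caller-allocated output, then the `tir_vars` values).

```python
from typing import Callable, List, Sequence


def add_one(a: List[float], out: List[float], n: int) -> None:
    """Hypothetical stand-in for a TIR function with signature (A, Out, n).

    `n` is a trailing scalar parameter, analogous to a trailing TIR var in
    the function signature; the caller must supply it explicitly.
    """
    for i in range(n):
        out[i] = a[i] + 1.0


def call_tir_sketch(
    func: Callable[..., None],
    args: Sequence[List[float]],
    out_size: int,
    tir_vars: Sequence[int] = (),
) -> List[float]:
    """Hypothetical mimic of the call_tir calling order (not the TVM API).

    Allocates the output buffer, then calls `func` with the inputs, the
    output, and finally the values in `tir_vars`, in that order.
    """
    out = [0.0] * out_size
    func(*args, out, *tir_vars)
    return out


n = 3
x = [1.0, 2.0, 3.0]
y = call_tir_sketch(add_one, (x,), out_size=n, tir_vars=(n,))
print(y)  # [2.0, 3.0, 4.0]
```

Omitting `tir_vars` here leaves the trailing parameter unbound and the call fails, which mirrors why the values must be passed explicitly.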

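The specific issues fixed in this PR are:

- On the printer side, `tir_vars` was not printed after `dtype` with an explicit `tir_vars=` keyword.
- On the `call_tir` side, a `tir_vars` argument given as a tuple or list was not wrapped into a `ShapeExpr`.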

Therefore, this PR fixes both sides. On the printer side, we now print `tir_vars` after `dtype` and explicitly print `tir_vars=`. On the `call_tir` side, we wrap the input `tir_vars` into a `ShapeExpr` if it is a tuple or list. A regression test covering both issues is included.
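The two fixes can be illustrated with a small self-contained sketch. It does not use TVM: `ShapeExprSketch` is a hypothetical stand-in for `relax.ShapeExpr`, and `normalize_tir_vars` / `print_call_tir` are hypothetical helpers that model the described behavior (wrap a tuple or list into a shape expression; print `tir_vars` after `dtype` as a keyword argument). The exact TVMScript text emitted by the real printer may differ from the `R.call_tir(...)` string produced here.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union


@dataclass(frozen=True)
class ShapeExprSketch:
    """Hypothetical stand-in for relax.ShapeExpr: an ordered tuple of values."""
    values: Tuple[str, ...]


def normalize_tir_vars(
    tir_vars: Optional[Union[ShapeExprSketch, Sequence[str]]],
) -> Optional[ShapeExprSketch]:
    """call_tir-side fix (sketch): wrap a tuple or list into a shape expression."""
    if isinstance(tir_vars, (tuple, list)):
        return ShapeExprSketch(tuple(tir_vars))
    # None or an existing ShapeExprSketch passes through unchanged.
    return tir_vars


def print_call_tir(
    func: str,
    args: Sequence[str],
    shape: Sequence[str],
    dtype: str,
    tir_vars: Optional[ShapeExprSketch] = None,
) -> str:
    """Printer-side fix (sketch): emit tir_vars after dtype, as a keyword."""

    def tup(items: Sequence[str]) -> str:
        # Python tuple syntax: a one-element tuple needs a trailing comma.
        return "(" + ", ".join(items) + ("," if len(items) == 1 else "") + ")"

    parts = [func, tup(args), tup(shape), f'dtype="{dtype}"']
    if tir_vars is not None:
        parts.append(f"tir_vars={tup(tir_vars.values)}")
    return "R.call_tir(" + ", ".join(parts) + ")"


tir_vars = normalize_tir_vars(["n"])  # a list is wrapped into ShapeExprSketch
print(print_call_tir("add_one", ["x"], ["n"], "float32", tir_vars))
# R.call_tir(add_one, (x,), (n,), dtype="float32", tir_vars=(n,))
```

Printing `tir_vars=` explicitly keeps the printed call unambiguous for the parser, since the value can no longer be mistaken for a positional argument such as `dtype`.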

cc @YuchenJin