tlc-pack / relax

Apache License 2.0

[Pass] Rewrite dataflow reshape #393

Closed MasterJH5574 closed 1 year ago

MasterJH5574 commented 1 year ago

This PR introduces the dataflow reshape rewrite pass, which transforms all dataflow TIR reshape bindings inside a Relax function into calls to the runtime packed function relax.vm.builtin.reshape. Here “dataflow TIR reshape bindings” means bindings that satisfy both of the following conditions:

  1. the binding value is a call_tir of a PrimFunc that essentially performs a reshape operation,
  2. the binding var is a DataflowVar, or in other words, the binding is inside a DataflowBlock and the binding var is not a block output var.
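To make condition 1 more concrete, here is a minimal brute-force sketch in plain Python of what "essentially doing a reshape" can be taken to mean: every output element is read from the input element that has the same row-major flattened offset. This is an assumption made for illustration, not the analysis function added by this PR (which works on TIR PrimFuncs); the names `flatten_index`, `is_reshape_mapping`, and the two example mappings are hypothetical.

```python
from itertools import product
from math import prod


def flatten_index(indices, shape):
    """Row-major flattened offset of `indices` within `shape`."""
    offset = 0
    for idx, extent in zip(indices, shape):
        offset = offset * extent + idx
    return offset


def is_reshape_mapping(out_shape, in_shape, out_to_in):
    """Check by enumeration that out[idx] = inp[out_to_in(*idx)] is a reshape.

    Illustrative assumption: a reshape keeps the row-major flattened
    offset of every element unchanged.
    """
    if prod(out_shape) != prod(in_shape):
        return False
    for out_idx in product(*(range(extent) for extent in out_shape)):
        in_idx = out_to_in(*out_idx)
        # Each input index must be in bounds for its own dimension.
        if any(not 0 <= i < extent for i, extent in zip(in_idx, in_shape)):
            return False
        if flatten_index(out_idx, out_shape) != flatten_index(in_idx, in_shape):
            return False
    return True


def reshape_2x6_to_3x4(i, j):
    """Output (3, 4) reads input (2, 6) at the same flattened offset."""
    flat = i * 4 + j
    return (flat // 6, flat % 6)


def transpose_3x4_to_4x3(i, j):
    """Output (4, 3) reads input (3, 4) transposed; not a reshape."""
    return (j, i)


print(is_reshape_mapping((3, 4), (2, 6), reshape_2x6_to_3x4))
print(is_reshape_mapping((4, 3), (3, 4), transpose_3x4_to_4x3))
```

The first mapping preserves flattened offsets and is accepted; the transpose moves elements in memory and is rejected, so a binding like it would have to stay a regular call_tir.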

The relax.vm.builtin.reshape packed function creates a view of the input NDArray with the target shape instead of copying the data.
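The difference between a view and a copy can be illustrated with a small standard-library analogy. The sketch below uses Python's `memoryview.cast` to give a flat byte buffer a new shape without copying it; it is not the implementation of relax.vm.builtin.reshape, and the helper name `reshape_view` is hypothetical.

```python
def reshape_view(buffer, shape):
    """Return a zero-copy view of a flat byte buffer with the given shape."""
    # memoryview.cast reinterprets the same memory; no bytes are copied.
    return memoryview(buffer).cast("B", shape=list(shape))


data = bytearray(range(6))        # flat buffer holding 0, 1, 2, 3, 4, 5
view = reshape_view(data, (2, 3))

# A write through the original buffer is visible through the reshaped
# view, which shows that both refer to the same underlying storage.
data[5] = 99
print(view.shape)
print(view.tolist())
```

Because the view shares storage with its input, the reshape costs no memory traffic. This also suggests why the pass is limited to bindings whose var is a DataflowVar: such a value stays inside its DataflowBlock, so the aliasing does not leak out as a block output.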

To implement this pass, this PR contains the following parts.

  1. An analysis function that detects the reshape pattern of PrimFuncs.
  2. The dataflow reshape rewrite pass itself.
  3. The implementation of relax.vm.builtin.reshape as a PackedFunc.

This PR contains unit tests for each part.