The `ComputePrimValue` transform is used to compute the value of symbolic expressions that may appear within a Relax function, for example, to compute a boolean condition used for a `relax::If` node. These functions are used for small host-side computations, prior to launching a device kernel.
This commit updates `ComputePrimValue` to annotate the generated `PrimFunc` with `tir::attr::kIsHostFunc`. This annotation is required for correct behavior in `tvm.dlight.ApplyDefaultSchedule`, to avoid erroneous scheduling of this function for the GPU, and in `tir::transform::BindTarget`, to ensure that the function is compiled for execution on the host.
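The effect of the annotation can be illustrated with a toy model in plain Python. This is a hypothetical sketch, not TVM's API. `ToyPrimFunc`, `compute_prim_value`, `apply_default_schedule`, and `bind_target` are illustrative stand-ins for the real transforms, and the `"tir.is_host_func"` key is assumed to be the string value of `tir::attr::kIsHostFunc`. The model shows a lifted condition being skipped by the default GPU schedule and bound to the host target, while an unannotated kernel is scheduled and bound to the device.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

# Assumed to match the string value of tir::attr::kIsHostFunc.
IS_HOST_FUNC = "tir.is_host_func"


@dataclass
class ToyPrimFunc:
    """Hypothetical stand-in for a PrimFunc: a name, a body, and attributes."""
    name: str
    body: Callable[..., object]
    attrs: Dict[str, object] = field(default_factory=dict)
    target: Optional[str] = None
    scheduled_for_gpu: bool = False


def compute_prim_value(name: str, expr: Callable[..., object]) -> ToyPrimFunc:
    """Lift a symbolic expression into its own function, marked as host-only."""
    return ToyPrimFunc(name=name, body=expr, attrs={IS_HOST_FUNC: True})


def apply_default_schedule(funcs: List[ToyPrimFunc]) -> None:
    """Toy analogue of a default GPU schedule: skip host-annotated functions."""
    for f in funcs:
        if not f.attrs.get(IS_HOST_FUNC, False):
            f.scheduled_for_gpu = True


def bind_target(funcs: List[ToyPrimFunc], device: str, host: str) -> None:
    """Toy analogue of target binding: host-annotated functions get the host target."""
    for f in funcs:
        f.target = host if f.attrs.get(IS_HOST_FUNC, False) else device


# A boolean condition on a symbolic shape, e.g. the condition of an If node.
cond = compute_prim_value("compute_cond", lambda n: n % 4 == 0)
# An ordinary device kernel with no host annotation.
kernel = ToyPrimFunc("matmul_kernel", body=lambda n: None)

funcs = [cond, kernel]
apply_default_schedule(funcs)
bind_target(funcs, device="cuda", host="llvm")

print(cond.target, cond.scheduled_for_gpu)      # llvm False
print(kernel.target, kernel.scheduled_for_gpu)  # cuda True
print(cond.body(16))                            # True
```

Without the attribute, the toy `compute_cond` would be treated like `matmul_kernel`: scheduled for the GPU and bound to the device target, which mirrors the misbehavior this change avoids.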