
[Bug] [Relax] Build fails when applying `dlight.gpu.GeneralReduction` to `R.nn.group_norm` with dynamic shapes and `R.reshape` #17531


Yumin-gd commented 1 week ago

Actual behavior

When building the TVMScript module below with `dlight.gpu.GeneralReduction()`, the build fails with:

InternalError: Check failed: (!divisor.is_const(0)) is false: Find divide by zero

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def reshape_norm(
        inp_0: R.Tensor((1, 512, "w", "h"), dtype="float16"), 
        inp_1: R.Tensor((512,), dtype="float16"), 
        inp_2: R.Tensor((512,), dtype="float16")
    ) -> R.Tensor((1, 512, "w * h"), dtype="float16"):
        w = T.int64()
        h = T.int64()
        with R.dataflow():
            lv = R.reshape(inp_0, R.shape([1, 512, w * h]))
            lv1 = R.nn.group_norm(
                data=lv,
                gamma=inp_1,
                beta=inp_2,
                num_groups=32,
                channel_axis=1,
                axes=[2],
                epsilon=9.9999999999999995e-07,
                center=True,
                scale=True,
            )
            R.output(lv1)
        return lv1
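
The full pipeline in "Steps to reproduce" is not required to hit the check failure. A minimal sketch, assuming the error is raised while the dlight rule schedules the legalized group_norm PrimFunc rather than in a later lowering pass:

import tvm
from tvm import relax
import tvm.dlight as dl

# Lower R.reshape / R.nn.group_norm to TIR PrimFuncs first.
lowered = relax.transform.LegalizeOps()(Module)

# Apply only the rule under suspicion; dlight schedule rules are
# applied inside a target scope.
with tvm.target.Target("cuda"):
    dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(lowered)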

Environment

Steps to reproduce

import tvm
from tvm import relax
import tvm.dlight as dl

@tvm.transform.module_pass(opt_level=0)
def dynshape_build_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
    seq = tvm.transform.Sequential(
        [
            relax.backend.DispatchSampling(),
            relax.backend.DispatchSortScan(),
            relax.transform.LegalizeOps(),
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.ToNonDataflow(),
            relax.transform.RemovePurityChecking(),
            relax.transform.CallTIRRewrite(),
            relax.transform.StaticPlanBlockMemory(),
            relax.transform.RewriteCUDAGraph(),
            relax.transform.LowerAllocTensor(),
            relax.transform.KillAfterLastUse(),
            relax.transform.LowerRuntimeBuiltin(),
            relax.transform.ComputePrimValue(),
            relax.transform.VMShapeLower(),
            relax.transform.AttachGlobalSymbol(),
        ],
    )
    mod = seq(mod)
    return mod

# `Module` is the TVMScript module from 'Actual behavior' above.
mod = Module
mod = relax.get_pipeline()(mod)
target = tvm.target.Target("cuda")
ex = relax.build(mod, target=target, pipeline=dynshape_build_pipeline)
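
As a diagnostic sketch only (not a fix, and assuming the zero divisor comes from the symbolic `w * h` extent), binding the shape variables to arbitrary concrete values before rerunning the identical pipeline should show whether the static case schedules cleanly:

# Diagnostic sketch: bind the symbolic vars to placeholder concrete
# extents, assuming the divide-by-zero is specific to dynamic shapes.
bindings = {
    "w": tvm.tir.IntImm("int64", 64),
    "h": tvm.tir.IntImm("int64", 64),
}
static_mod = relax.transform.BindSymbolicVars(bindings)(Module)
static_mod = relax.get_pipeline()(static_mod)
ex_static = relax.build(static_mod, target=target, pipeline=dynshape_build_pipeline)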

cc @junrushao