apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0
11.66k stars 3.45k forks source link

[Bug] [Metaschedule] Cannot tune dense with batch > 1 and not divisible by 16 for cuda #14935

Open elvin-n opened 1 year ago

elvin-n commented 1 year ago

Unable to tune dense/matmul with an M dimension greater than 1 (for dense, M is the batch dimension). Tuning on CPU works fine.

Update: if M is divisible by 16, tuning works.

Code to reproduce:

import numpy as np
import tvm
from tvm import relay
from tvm import meta_schedule as ms
from tvm.relay.backend import Executor

# -------- Func definition
# Workload: out[M, N] = dense(data[M, K], w1[N, K]) in float16.
# M (the batch dimension for dense) is what triggers the CUDA tuning failure:
# M == 1, or any M divisible by 16, tunes fine; e.g. M == 4 does not.
dtype = "float16"
input_shape = (4, 2048)  # (M, K); M == 1 or M % 16 == 0 can be tuned for cuda
filter_shape = (768, 2048)  # (N, K)
A = relay.var("data", shape=input_shape, dtype=dtype)
W = relay.var("w1", shape=filter_shape, dtype=dtype)
dense = relay.nn.dense(A, W, out_dtype=dtype)
func = relay.Function([A, W], dense)

np.random.seed(0)
# Weight values passed to tune_relay as `params`; all zeros, presumably because
# only the shapes/dtype matter for reproducing the failure.
filter_data1 = np.zeros(filter_shape, dtype=dtype)
params1 = {
    "w1": tvm.nd.array(filter_data1),
}

mod = tvm.IRModule.from_expr(func)

# ------ Tune through metascheduler
# Derive the workload name from the shapes so that work_dir (where the tuning
# records are written) stays in sync when input_shape / filter_shape are edited.
# For the shapes above this yields "dense_4_2048_2048_768".
m_dim, k_dim = input_shape
n_dim, k_dim_w = filter_shape
name = f"dense_{m_dim}_{k_dim}_{k_dim_w}_{n_dim}"
work_dir = f"./{name}/"
module_equality_name = "ignore-ndarray"
strategy_name = "evolutionary"

# CUDA target (GeForce RTX 2060 tag) with LLVM as the host target.
target = tvm.target.Target("nvidia/geforce-rtx-2060", host="llvm")
executor = Executor("graph")
mod = mod.with_attr("executor", executor)
builder = ms.builder.LocalBuilder(timeout_sec=60)
evaluator_config = ms.runner.EvaluatorConfig(
    number=3,
    repeat=1,
    min_repeat_ms=100,
    enable_cpu_cache_flush=False,
)
runner = ms.runner.LocalRunner(
    evaluator_config=evaluator_config,
    alloc_repeat=1,
)
ms.relay_integration.tune_relay(
    mod=mod,
    target=target,
    params=params1,
    work_dir=work_dir,
    max_trials_global=1024,
    strategy=strategy_name,
    builder=builder,
    runner=runner,
    module_equality=module_equality_name,
)

cc @ibsidorenko

Hzfengsy commented 1 year ago

cc @vinx13