apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] te.gradient does not work with complex forward workload #9715

Closed kk2049 closed 2 months ago

kk2049 commented 2 years ago

My problem

I am trying to use the auto-scheduler to generate CUDA source code for the backward stage of an NCHW winograd conv2d. Due to some bugs in topi.cuda.conv2d_winograd.winograd_cuda, I copied some of its code to build my own workload.

Luckily, this workload works without te.gradient, and I can successfully generate source code for the forward stage. But when I add te.gradient, the workload no longer works and I get the error message below: Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, ......
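
For reference, the te.gradient pattern I am following is the one in the minimal sketch below. It uses a toy elementwise workload instead of the winograd compute, just to show the structure of the call (an adjoint placeholder dy for the output, then te.gradient over the output with respect to the input):

import tvm
from tvm import te

# Toy elementwise workload: y[i] = x[i] * x[i]
n = 16
x = te.placeholder((n,), name="x")
y = te.compute((n,), lambda i: x[i] * x[i], name="y")

# Adjoint placeholder for the output, then reverse-mode gradient w.r.t. x
dy = te.placeholder(y.shape, name="dy")
[dx] = te.gradient(y, [x], head=dy)

# Build with a default schedule on CPU
s = te.create_schedule(dx.op)
f = tvm.build(s, [x, dy, dx], target="llvm")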

I am really confused now. The fact that forward-stage codegen works suggests that my workload is correct in some sense, so I suspect this may be caused by a bug in TVM, but I am not sure.

Maybe someone can help me find out whether this is a bug in TVM.

Thanks a lot!!!

Expected behavior

This code should find a valid schedule.

Actual behavior

I get the error below when I start tuning.

Get devices for measurement successfully!
----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "bug_scheduler.py", line 189, in <module>
    task.tune(tune_option)
  File "/data/anaconda3/envs/env3.7/lib/python3.7/site-packages/tvm-0.8.0-py3.7-linux-x86_64.egg/tvm/auto_scheduler/search_task.py", line 498, in tune
    _ffi_api.AutoSchedule(search_policy, tuning_options)
  File "/data/anaconda3/envs/env3.7/lib/python3.7/site-packages/tvm-0.8.0-py3.7-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  13: TVMFuncCall
  12: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  11: tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)
  10: tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)
  9: tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)
  8: tvm::auto_scheduler::SketchPolicyNode::GenerateSketches()
  7: tvm::auto_scheduler::RuleAddCacheRead::Apply(tvm::auto_scheduler::SketchPolicyNode const&, tvm::auto_scheduler::State const&, int) const
  6: tvm::auto_scheduler::State::cache_read(int, tvm::runtime::String const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::auto_scheduler::ComputeDAG const&)
  5: tvm::auto_scheduler::CacheReadStepNode::ApplyToState(tvm::auto_scheduler::State*, tvm::auto_scheduler::ComputeDAG const&) const
  4: tvm::auto_scheduler::ComputeDAG::ReplayAndGetDAG(tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&) const
  3: tvm::auto_scheduler::ComputeDAG::ApplySteps(tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::auto_scheduler::LayoutRewriteOption) const
  2: tvm::auto_scheduler::StepApplyToSchedule(tvm::auto_scheduler::Step const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*, tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&)
  1: tvm::auto_scheduler::CacheReadStepNode::ApplyToSchedule(tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*) const
  0: tvm::te::Schedule::cache_read(tvm::te::Tensor const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Array<tvm::te::Operation, void> const&)
  File "/data/apache-tvm-src-v0.8.0.rc0/src/te/schedule/schedule_dataflow_rewrite.cc", line 168
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, body=[extracted_tensor[ax0, ax1, ax2, ax3]], axis=[iter_var(ax0, range(min=0, ext=2)), iter_var(ax1, range(min=0, ext=2)), iter_var(ax2, range(min=0, ext=4)), iter_var(ax3, range(min=0, ext=4))], reduce_axis=[], tag=, attrs={})

Environment

My system is Ubuntu 16.04 and my CUDA version is 10.2. My TVM version is 0.8.0, built from the source on the Download Apache TVM Source Code web page.

Steps to reproduce

I am sorry to post such long code here, but it is needed to make sure this bug can be reproduced. I have tried to cut out parts of the code to reproduce the error, but the bug is only triggered by this full version.

import os

import numpy as np
import tvm
from tvm import auto_scheduler

import logging
from tvm import te, topi
from tvm import autotvm

from tvm.topi import nn
from tvm.topi.utils import get_const_int, get_const_tuple, traverse_inline
from tvm.topi.nn.winograd_util import winograd_transform_matrices
from tvm.topi.nn.conv2d import conv2d_winograd_nhwc, _conv2d_winograd_nhwc_impl
import sys
from tvm.topi.testing import conv2d_nchw_python

def _infer_tile_size(data, kernel, layout="NCHW"):
    if layout == "NCHW":
        N, CI, H, W = get_const_tuple(data.shape)
    else:
        assert layout == "NHWC"
        N, H, W, CI = get_const_tuple(data.shape)

    if H % 8 == 0:
        return 4
    return 2

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")

    stride = (1,1)
    padding = (1,1)
    dilation = (1,1)
    pre_computed = False
    out_dtype = "float32"

    tile_size = _infer_tile_size(data, kernel)
    N, CI, H, W = get_const_tuple(data.shape)

    if isinstance(N, tvm.tir.Any):
        N = tvm.te.size_var("n")

    if not isinstance(H, int) or not isinstance(W, int):
        raise RuntimeError(
            "cuda winograd conv2d doesn't support dynamic input\
                           height or width."
        )

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
    HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride

    if not pre_computed:  # kernel tensor is raw tensor, do strict check
        if dilation_h != 1 or dilation_w != 1:
            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
        CO, CI, KH, KW = get_const_tuple(kernel.shape)
        alpha = KW + tile_size - 1
        assert HSTR == 1 and WSTR == 1 and KH == KW
    else:
        # kernel tensor is pre-transformed. this op is created by alter op layout.
        # dilation is not supported
        alpha, _, CI, CO = get_const_tuple(kernel.shape)
        KH = KW = alpha + 1 - tile_size
        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1

    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")

    r = KW
    m = tile_size
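    # A, B, G are the Winograd transform matrices: G transforms the kernel,
    # B transforms the input tiles, and A is the inverse (output) transform.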
    A, B, G = winograd_transform_matrices(m, r, out_dtype)

    H = (H + pt + pb - KH) // HSTR + 1
    W = (W + pl + pr - KW) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    P = N * nH * nW if isinstance(N, int) else nH * nW

    # transform kernel
    if not pre_computed:
        r_kh = te.reduce_axis((0, KH), name="r_kh")
        r_kw = te.reduce_axis((0, KW), name="r_kw")
        kernel_pack = te.compute(
            (alpha, alpha, CI, CO),
            lambda eps, nu, ci, co: te.sum(
                kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
            ),
            name="my_kernel_pack",
        )
    else:
        kernel_pack = kernel    

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    # pack input tile
    input_tile = te.compute(
        (CI, P, alpha, alpha),
        lambda c, p, eps_1, nu_1: data_pad[idxdiv(p, (nH * nW))][c][
            idxmod(idxdiv(p, nW), nH) * m + eps_1
        ][idxmod(p, nW) * m + nu_1],
        name="my_d",
    )

    # dy = tvm.te.placeholder(input_tile.shape, name="input2_dy")
    # [dw] = tvm.te.gradient(input_tile, [data], head=dy)
    # return [data, kernel, input_tile, dy, dw]

    # transform data
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    data_pack = te.compute(
        (alpha, alpha, CI, P),
        lambda eps, nu, ci, p: te.sum(
            input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
        ),
        name="my_data_pack",
    )

    # dy = tvm.te.placeholder(data_pack.shape, name="input2_dy")
    # [dw] = tvm.te.gradient(data_pack, [data], head=dy)
    # return [data, kernel, data_pack, dy, dw]

    # do batch gemm
    ci = te.reduce_axis((0, CI), name="ci")
    bgemm = te.compute(
        (alpha, alpha, CO, P),
        lambda eps, nu, co, p: te.sum(
            kernel_pack[eps][nu][ci][co] * data_pack[eps][nu][ci][p], axis=[ci]
        ),
        name="my_bgemm",
    )
    # inverse transform
    r_a_2 = te.reduce_axis((0, alpha), "r_a_2")
    r_b_2 = te.reduce_axis((0, alpha), "r_b_2")
    inverse = te.compute(
        (CO, P, m, m),
        lambda co, p, vh, vw: te.sum(
            bgemm[r_a_2][r_b_2][co][p] * A[r_a_2][vh] * A[r_b_2][vw], axis=[r_a_2, r_b_2]
        ),
        name="my_inverse",
    )

    # output
    output = te.compute(
        (N, CO, H, W),
        lambda n, co, h, w: inverse[
            co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
        ],
        name="my_output",
        tag="conv2d_nchw_winograd",
    )

    dy = tvm.te.placeholder(output.shape, name="input2_dy")
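    # Reverse-mode gradient: dy is the adjoint (head) of my_output;
    # te.gradient returns d(my_output)/d(data) contracted with dy.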
    [dw] = tvm.te.gradient(output, [data], head=dy)
    return [data, kernel, output, dy, dw]
    # return [data, kernel, output]

target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

log_file = "conv2d.json"
if os.path.exists(log_file):
    os.remove(log_file)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
kk2049 commented 2 years ago

@comaniac Sorry to bother you. (I really appreciated your help with te.gradient a few months ago in #8991.) I wonder if I could get your help again with this problem. I am confused by this bug and have no idea how to fix it. Thanks a lot!!

comaniac commented 2 years ago

It looks like the auto-scheduler has issues when generating the schedule sketch for this workload. You could first try to build and run this workload on CPU without tuning to see if we can identify the problem. If that doesn't work, then something is wrong with the workload or with te.gradient itself. Otherwise, we can investigate the compute DAG to see why the auto-scheduler fails on the workload generated by te.gradient.
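
A rough, untested sketch of that CPU check (reusing the conv2d_layer function from your script above and feeding random inputs) could look like this:

import numpy as np
import tvm
from tvm import te
from tvm.topi.utils import get_const_tuple

# Build the TE graph directly, without the auto-scheduler
data, kernel, output, dy, dw = conv2d_layer(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))
s = te.create_schedule([output.op, dw.op])
func = tvm.build(s, [data, kernel, output, dy, dw], target="llvm")

# Run once on CPU with random inputs
dev = tvm.cpu(0)
def buf(t):
    return tvm.nd.array(
        np.random.uniform(size=get_const_tuple(t.shape)).astype("float32"), dev
    )
a, w, out, head, grad = (buf(t) for t in (data, kernel, output, dy, dw))
func(a, w, out, head, grad)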

kk2049 commented 2 years ago

@comaniac Thanks for your reply! I have tried to run this workload with tvm.target.Target("llvm"), and it can be launched successfully. So I selected target("cuda") again and printed the compute DAG. It looks like this:

Computational DAG:
kernel = PLACEHOLDER [512, 512, 3, 3]
G(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 3) == 2)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 3) == 1)),  ..(OMITTED).. (floormod(i, 4) == 0) && (floormod(j, 3) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 3) == 0)), 1f, 0f))))))))))))
my_kernel_pack(eps, nu, ci, co) += ((kernel[co, ci, r_kh, r_kw]*G[eps, r_kh])*G[nu, r_kw])
data = PLACEHOLDER [1, 512, 7, 7]
data_pad(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
my_d(c, p, eps_1, nu_1) = data_pad[floordiv(p, 16), c, ((floormod(floordiv(p, 4), 4)*2) + eps_1), ((floormod(p, 4)*2) + nu_1)]
B(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 4) == 3)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 4) == 2)),  ..(OMITTED).. ormod(i, 4) == 0) && (floormod(j, 4) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 4) == 0)), 1f, 0f))))))))))))))))
my_data_pack(eps, nu, ci, p) += ((my_d[ci, p, r_a, r_b]*B[r_a, eps])*B[r_b, nu])
my_bgemm(eps, nu, co, p) += (my_kernel_pack[eps, nu, ci, co]*my_data_pack[eps, nu, ci, p])
A(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 2) == 1)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 2) == 0)),  ..(OMITTED).. ct(((floormod(i, 4) == 0) && (floormod(j, 2) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 2) == 0)), 1f, 0f))))))))
my_inverse(co, p, vh, vw) += ((my_bgemm[r_a_2, r_b_2, co, p]*A[r_a_2, vh])*A[r_b_2, vw])
my_output(n, co, h, w) = my_inverse[co, ((((n*4)*4) + (floordiv(h, 2)*4)) + floordiv(w, 2)), floormod(h, 2), floormod(w, 2)]
input2_dy = PLACEHOLDER [1, 512, 7, 7]
my_output.my_inverse.grad(ax0, ax1, ax2, ax3) = select((((((((ax2*4) + (floordiv((7 + (ax1*-2)), 8)*-8)) <= 24) && (((ax1*-2) +  ..(OMITTED).. ) <= 15)), input2_dy[0, ax0, (ax2 + (floordiv((7 + (ax1*-2)), 8)*-2)), (((floordiv((7 + (ax1*-2)), 8)*8) + (ax1*2)) + ax3)], 0f)
extracted_tensor(n0_n0_vh.shifted.shifted, n1_n1_vw.shifted.shifted, n2_n2_jac_i0.shifted.shifted, n3_n3_jac_i1.shifted.shifted) = (A[n2_n2_jac_i0.shifted.shifted, n0_n0_vh.shifted.shifted]*A[n3_n3_jac_i1.shifted.shifted, n1_n1_vw.shifted.shifted])
my_inverse.my_bgemm.grad(ax0, ax1, ax2, ax3) += (my_output.my_inverse.grad[ax2, ax3, n0_n0_k2.shifted.shifted, n1_n1_k3.shifted.shifted]*extracted_tensor[n0_n0_k2.shifted.shifted, n1_n1_k3.shifted.shifted, ax0, ax1])
my_bgemm.my_data_pack.grad(ax0, ax1, ax2, ax3) += (my_inverse.my_bgemm.grad[ax0, ax1, n0_n0_k2.shifted.shifted, ax3]*my_kernel_pack[ax0, ax1, ax2, n0_n0_k2.shifted.shifted])
extracted_tensor(n0_n0_eps.shifted.shifted, n1_n1_nu.shifted.shifted, n4_n4_jac_i2.shifted.shifted, n5_n5_jac_i3.shifted.shifted) = (B[n4_n4_jac_i2.shifted.shifted, n0_n0_eps.shifted.shifted]*B[n5_n5_jac_i3.shifted.shifted, n1_n1_nu.shifted.shifted])
my_data_pack.my_d.grad(ax0, ax1, ax2, ax3) += (my_bgemm.my_data_pack.grad[n0_n0_k0.shifted.shifted, n1_n1_k1.shifted.shifted, ax0, ax1]*extracted_tensor[n0_n0_k0.shifted.shifted, n1_n1_k1.shifted.shifted, ax2, ax3])
data_pad.data.grad(ax0, ax1, ax2, ax3) += my_data_pack.my_d.grad[ax1, (((((floordiv((ax2 + 1), 2) + n0_n0_fdiv1.shifted.shifted) ..(OMITTED).. ormod((ax2 + 1), 2) + (n0_n0_fdiv1.shifted.shifted*-2)) + 2), ((floormod((ax3 + 1), 2) + (n1_n1_fmod1.shifted.shifted*-2)) + 2)]

I have tried to check this DAG info myself but failed to find anything useful. Maybe you can find something in it?

Thanks a lot for your help!!!