apache / tvm

Open deep learning compiler stack for CPU, GPU, and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

Incorrect implementation of cross_entropy_with_logits #9109

Closed · SWu closed this issue 1 month ago

SWu commented 3 years ago

The implementation of cross_entropy_with_logits seems to be incorrect: https://github.com/apache/tvm/blob/main/python/tvm/relay/op/nn/_nn.py#L912
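
For reference, the compute registered at that line boils down to the following (paraphrased from the linked file; note that x is used directly, with no log_softmax applied to the logits):

from tvm import topi

# "reg" is the relay op registry imported in that module
@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
    x, y = inputs
    return [-topi.sum(x * y) / x.shape[0]]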

It should be something like:

-topi.sum(topi.nn.log_softmax(x) * y) / x.shape[0]
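
A quick NumPy check of the difference (a minimal sketch, not TVM code; the helper name is illustrative):

import numpy as np

def log_softmax(x):
    # numerically stable log-softmax along the last axis
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
onehot = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

current  = -np.sum(logits * onehot) / logits.shape[0]               # what the op does today
expected = -np.sum(log_softmax(logits) * onehot) / logits.shape[0]  # true CE-with-logits

print(current, expected)  # the values differ unless the inputs are already log-probabilities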

However, if I naively try to make the change above, I get the following error when trying to compile a model using it:

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0x8a) [0x7f22047564ba]
  [bt] (7) /usr/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::CallNode const*)+0x3be) [0x7f22048d981e]
  [bt] (6) /usr/tvm/build/libtvm.so(tvm::relay::OpMatch<void>::operator()(tvm::relay::Call const&)+0xef) [0x7f22048d88ff]
  [bt] (5) /usr/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::CallNode const*)::{lambda(tvm::Array<tvm::relay::Expr, void> const&, tvm::Attrs const&, tvm::Array<tvm::relay::Type, void> const&)#1}::operator()(tvm::Array<tvm::relay::Expr, void> const&, tvm::Attrs const&, tvm::Array<tvm::relay::Type, void> const&) const+0x13a) [0x7f22048d740a]
  [bt] (4) /usr/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::EmitInvokeTVMOp(tvm::relay::Function const&, tvm::relay::Expr const&, tvm::relay::Expr const&)+0x9eb) [0x7f22048d6d8b]
  [bt] (3) /usr/tvm/build/libtvm.so(tvm::relay::CompileEngineImpl::Lower(tvm::relay::CCacheKey const&)+0x20) [0x7f220487e9a0]
  [bt] (2) /usr/tvm/build/libtvm.so(tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)+0x329) [0x7f220487deb9]
  [bt] (1) /usr/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)+0xef2) [0x7f220487d412]
  [bt] (0) /usr/tvm/build/libtvm.so(+0xb8b2fb) [0x7f2204a272fb]
  File "/usr/tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "/usr/tvm/python/tvm/relay/op/_reduce.py", line 31, in _schedule_reduce
    return topi.generic.schedule_reduce(outs)
  File "<decorator-gen-92>", line 2, in schedule_reduce
  File "/usr/tvm/python/tvm/target.py", line 299, in dispatch_func
    return generic_func_node(*args)
  File "/usr/tvm/python/tvm/target.py", line 161, in __call__
    return _api_internal._GenericFuncCallFunc(self, *args)
  File "/usr/tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()
  [bt] (3) /usr/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f2204a2bbb1]
  [bt] (2) /usr/tvm/build/libtvm.so(+0x4c4fbf) [0x7f2204360fbf]
  [bt] (1) /usr/tvm/build/libtvm.so(tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x112) [0x7f2204360da2]
  [bt] (0) /usr/tvm/build/libtvm.so(+0xb8b2fb) [0x7f2204a272fb]
  File "/usr/tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 119, in schedule_reduce
    traverse_after_reduce(outs[0].op)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 100, in traverse_after_reduce
    traverse_after_reduce(tensor.op)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 100, in traverse_after_reduce
    traverse_after_reduce(tensor.op)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 105, in traverse_after_reduce
    traverse_before_reduce(tensor.op)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 88, in traverse_before_reduce
    traverse_before_reduce(tensor.op)
  File "/usr/tvm/topi/python/topi/x86/reduction.py", line 90, in traverse_before_reduce
    raise RuntimeError("Unsupported operator: %s" % operator.tag)
RuntimeError: Unsupported operator: log_softmax_output
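
The traceback comes from the x86 reduce schedule: when it walks the producers feeding a reduction, it only accepts placeholder or injective-tagged ops, and log_softmax carries the non-injective tag log_softmax_output. Roughly (a paraphrase of the check in topi/python/topi/x86/reduction.py; the exact structure varies across TVM versions):

from tvm import te
from tvm.topi import tag

def traverse_before_reduce(operator):
    # graph inputs are always acceptable
    if isinstance(operator, te.PlaceholderOp):
        return
    # elementwise/broadcast producers can be inlined into the reduction
    if tag.is_injective(operator.tag):
        for tensor in operator.input_tensors:
            traverse_before_reduce(tensor.op)
    else:
        # log_softmax is tagged "log_softmax_output", so it lands here
        raise RuntimeError("Unsupported operator: %s" % operator.tag)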
SWu commented 3 years ago

FWIW, the current cross_entropy_with_logits is actually what PyTorch calls NLLLoss (negative log-likelihood loss), so a workaround to get true cross-entropy-with-logits behavior is cross_entropy_with_logits(log_softmax(y_pred), y_target).
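
A minimal Relay sketch of that workaround (shapes and variable names here are illustrative):

from tvm import relay

# raw logits and one-hot targets; shapes are illustrative
y_pred = relay.var("y_pred", shape=(4, 10), dtype="float32")
y_target = relay.var("y_target", shape=(4, 10), dtype="float32")

# apply log_softmax ourselves, then the (mislabeled) op, which just
# computes -sum(x * y) / batch, i.e. NLL on its inputs
loss = relay.nn.cross_entropy_with_logits(relay.nn.log_softmax(y_pred), y_target)
func = relay.Function([y_pred, y_target], loss)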

SWu commented 2 years ago

bump

This is still incorrect in the latest mainline: https://github.com/apache/tvm/blob/main/python/tvm/relay/op/nn/_nn.py#L1012

Can we at least delete this operator in the meantime, to avoid confusing people who use it expecting it to be correct?