apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

symbol.contrib.cond does not support custom operator execution #16187

Open igolan opened 5 years ago

igolan commented 5 years ago

Description

symbol.contrib.cond operator does not support custom operator execution.

Environment info (Required)

----------Python Info----------
Version      : 3.7.4
Compiler     : Clang 10.0.1 (clang-1001.0.46.4)
Build        : ('default', 'Jul  9 2019 18:13:23')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 19.0.3
Directory    : /Users/XX/PycharmProjects/XX/venv/lib/python3.7/site-packages/pip-19.0.3-py3.7.egg/pip
----------MXNet Info-----------
Version      : 1.5.0
Directory    : /Users/XX/PycharmProjects/XX/venv/lib/python3.7/site-packages/mxnet
Commit Hash   : 75a9e187d00a8b7ebc71412a02ed0e3ae489d91f
Library      : ['/Users/XX/PycharmProjects/XX/venv/lib/python3.7/site-packages/mxnet/libmxnet.so']
Build features:
✖ CUDA
✖ CUDNN
✖ NCCL
✖ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✖ OPENMP
✖ SSE
✖ F16C
✖ JEMALLOC
✖ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✖ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✔ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
----------System Info----------
Platform     : Darwin-18.7.0-x86_64-i386-64bit
system       : Darwin
node         : XXX
release      : 18.7.0
version      : Darwin Kernel Version 18.7.0: Tue Aug 20 16:57:14 PDT 2019; root:xnu-4903.271.2~2/RELEASE_X86_64
----------Hardware Info----------
machine      : x86_64
processor    : i386
b'machdep.cpu.brand_string: Intel(R) Core(TM) i7-7660U CPU @ 2.50GHz'
b'machdep.cpu.features: FPU VME DE PSE TSC MSR PAE MCE CX8 APIC SEP MTRR PGE MCA CMOV PAT PSE36 CLFSH DS ACPI MMX FXSR SSE SSE2 SS HTT TM PBE SSE3 PCLMULQDQ DTES64 MON DSCPL VMX SMX EST TM2 SSSE3 FMA CX16 TPR PDCM SSE4.1 SSE4.2 x2APIC MOVBE POPCNT AES PCID XSAVE OSXSAVE SEGLIM64 TSCTMR AVX1.0 RDRAND F16C'
b'machdep.cpu.leaf7_features: RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 HLE AVX2 SMEP BMI2 ERMS INVPCID RTM FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT MDCLEAR TSXFA IBRS STIBP L1DF SSBD'
b'machdep.cpu.extfeatures: SYSCALL XD 1GBPAGE EM64T LAHF LZCNT PREFETCHW RDTSCP TSCI'
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0137 sec, LOAD: 0.5112 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0180 sec, LOAD: 0.4525 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.0198 sec, LOAD: 0.8612 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0233 sec, LOAD: 0.1894 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0120 sec, LOAD: 0.3173 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0105 sec, LOAD: 0.0961 sec.
----------Environment----------

I'm using Python.

Build info (Required if built from source)

N/A

Error Message:

Traceback (most recent call last):
  File "_ctypes/callbacks.c", line 232, in 'calling callback function'
  File "/Users/XX/PycharmProjects/XX/venv/lib/python3.7/site-packages/mxnet/operator.py", line 718, in creator
    op_prop = prop_cls(**kwargs)
TypeError: __init__() got an unexpected keyword argument '__subgraph_name__'

Segmentation fault: 11

Stack trace:
  [bt] (0) 1   libmxnet.so                         0x000000011705c2b0 mxnet::Storage::Get() + 4880
  [bt] (1) 2   libsystem_platform.dylib            0x00007fff57f9eb5d _sigtramp + 29
  [bt] (2) 3   Python                              0x000000010dcd7194 _PyMethodDef_RawFastCallDict + 591
  [bt] (3) 4   libmxnet.so                         0x0000000115698206 mxnet::NDArray::set_aux_shape(unsigned long, mxnet::TShape const&) const + 177878
  [bt] (4) 5   libmxnet.so                         0x00000001174c6fee NNSymbolCompose + 89646
  [bt] (5) 6   libmxnet.so                         0x00000001168eb3d6 MXSymbolCreateAtomicSymbol + 4086
  [bt] (6) 7   _ctypes.cpython-37m-darwin.so       0x000000010e1c636f ffi_call_unix64 + 79
  [bt] (7) 8   ???                                 0x00007ffee1f469d0 0x0 + 140732689312208

Minimum reproducible example

import mxnet as mx
from mxnet import nd, autograd, gluon

class IdentityOP(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        self.assign(out_data[0], req[0], in_data[0])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register("identityop")
class IdentityOPProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(IdentityOPProp, self).__init__(True)

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return IdentityOP()

class MLP(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        with self.name_scope():
            self.dense1 = gluon.nn.Dense(1, in_units=1)

    def hybrid_forward(self, F, x):
        # Not working:
        cond_out = F.contrib.cond(
            F.ones(1) == F.ones(1),
            lambda: self.dense1(x),
            lambda: mx.symbol.Custom(data=x, name='identityop', op_type='identityop'))
        # Working:
        # cond_out = F.contrib.cond(F.ones(1) == F.ones(1), lambda: self.dense1(x), lambda: x)
        return cond_out

model_ctx = mx.cpu()
net = MLP()
net.hybridize()
net.collect_params().initialize(mx.init.Constant([1]), ctx=model_ctx)
data = nd.ones((3,1))
with mx.autograd.record():
    out = net(data.as_in_context(model_ctx))
out.backward()
print(net.dense1.weight.grad())
with mx.autograd.record():
    out = net(data.as_in_context(model_ctx))
out.backward()
print(net.dense1.weight.grad())

Steps to reproduce

Run code above

What have you tried to solve it?

  1. Replaced the custom operator with no operator (or a built-in operator) - this works (see the comment in hybrid_forward).

I suspect it has something to do with custom operators being executed imperatively. Might be related to #12154, #11641 and #16182.

*I'm not certain the custom operator implementation itself isn't missing something; the attached example uses a simple identity custom operator, and it still fails.

sangrey commented 4 years ago

Was this ever resolved, or was a workaround ever found?

leezu commented 4 years ago

You can consider trying https://github.com/apache/incubator-mxnet/tree/83797400128d41910d87e957131f29fd466f4777/example/extensions/lib_custom_op instead of the Python-level custom operator. But I'm not sure if it has been tested with the conditional operator yet.

samskalicky commented 4 years ago

Here's the full stack trace from the given example on MXNet v1.8:

#0  0x00007fffdee8748d in std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > mxnet::op::custom::List<(CustomOpPropCallbacks)1>(nnvm::NodeAttrs const&) () from /home/ubuntu/18_fixes/python/mxnet/../../lib/libmxnet.so
#1  0x00007fffdee77bd8 in mxnet::op::custom::AttrParser(nnvm::NodeAttrs*) ()
   from /home/ubuntu/18_fixes/python/mxnet/../../lib/libmxnet.so
#2  0x00007fffe54d8a32 in nnvm::Symbol::CreateFunctor(nnvm::Op const*, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >) ()
   from /home/ubuntu/18_fixes/python/mxnet/../../lib/libmxnet.so
#3  0x00007fffe3c492d1 in MXSymbolCreateAtomicSymbol () from /home/ubuntu/18_fixes/python/mxnet/../../lib/libmxnet.so
#4  0x00007ffff6607ec0 in ffi_call_unix64 () from /home/ubuntu/anaconda3/lib/python3.6/lib-dynload/../../libffi.so.6
#5  0x00007ffff660787d in ffi_call () from /home/ubuntu/anaconda3/lib/python3.6/lib-dynload/../../libffi.so.6
#6  0x00007ffff681cdae in _ctypes_callproc ()
   from /home/ubuntu/anaconda3/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so
#7  0x00007ffff681d7e5 in PyCFuncPtr_call ()

It's failing in the MXNet Python custom op, but on the C++ side here: https://github.com/apache/incubator-mxnet/blob/v1.x/src/operator/custom/custom.cc#L106 after it calls back into Python here: https://github.com/apache/incubator-mxnet/blob/v1.x/python/mxnet/operator.py#L733

samskalicky commented 4 years ago

A simple fix seems to be adding **kwargs to the IdentityOPProp constructor, like:

@mx.operator.register("identityop")
class IdentityOPProp(mx.operator.CustomOpProp):
    def __init__(self, **kwargs):
        super(IdentityOPProp, self).__init__(True)
        print('IdentityOPProp: %s' % kwargs)

Now it works and prints:

IdentityOPProp: {'__subgraph_name__': 'cond_else0'}

[[3.]]
<NDArray 1x1 @cpu(0)>

[[3.]]
<NDArray 1x1 @cpu(0)>

So the question is, how/why is __subgraph_name__ getting added to this op's attributes?
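The failure mechanism itself can be sketched in plain Python (no MXNet required; class and function names below are illustrative, not MXNet's): per the original traceback, the creator callback in mxnet/operator.py instantiates the registered prop class as `prop_cls(**kwargs)`, and when a subgraph operator like `cond` rebuilds its branch, the attribute dict appears to include framework-injected keys such as `__subgraph_name__`. Any prop class whose `__init__` doesn't accept `**kwargs` then raises the observed `TypeError`.

```python
# Illustrative sketch of the prop instantiation path (no MXNet needed).
# Names here are stand-ins for the real classes in the bug report.

class StrictProp:
    """Mirrors the original IdentityOPProp: __init__ takes no **kwargs."""
    def __init__(self):
        pass

class TolerantProp:
    """Mirrors the fixed version: swallows extra keyword attributes."""
    def __init__(self, **kwargs):
        self.extra = kwargs

def create_prop(prop_cls, **kwargs):
    # Stand-in for the creator callback in mxnet/operator.py, which
    # does: op_prop = prop_cls(**kwargs)
    return prop_cls(**kwargs)

# Attributes the subgraph machinery appears to inject (value taken
# from samskalicky's debug output above).
injected = {'__subgraph_name__': 'cond_else0'}

try:
    create_prop(StrictProp, **injected)
except TypeError as e:
    # Same shape of failure as the reported traceback.
    print('StrictProp failed:', e)

prop = create_prop(TolerantProp, **injected)
print('TolerantProp got:', prop.extra)
```

This is why forwarding `**kwargs` in the constructor makes the example work: the injected attribute is simply absorbed instead of breaking instantiation.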

sangrey commented 4 years ago

That does look like the issue we're having. I'll try that workaround. Thanks!