apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

MKLDNN doesn't support 6 dimensions #14766

Closed: awsbillz closed this issue 5 years ago

awsbillz commented 5 years ago

Hi MXNet,

We are running into an issue where the MKLDNN implementation of MXNet tensors doesn't support 5+ dimensions. Is there a specific reason for this? We use Reshape and Transpose operators that manipulate layer outputs into 6D tensors.

https://github.com/apache/incubator-mxnet/blob/master/src/ndarray/ndarray.cc#L475

File "/home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 1980, in asnumpy 
ctypes.c_size_t(data.size))) 
File "/home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/base.py", line 252, in check_call 
raise MXNetError(py_str(_LIB.MXGetLastError())) 
mxnet.base.MXNetError: [00:30:20] src/ndarray/ndarray.cc:481: MKLDNN doesn't support 6 dimensions 

We are currently running on MXNet 1.4.0 with MKLDNN on C5 instances.

Thanks,

Bill

mxnet-label-bot commented 5 years ago

Hey, this is the MXNet Label Bot. Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it. Here are my recommended labels: Feature

yuxihu commented 5 years ago

@TaoLv @pengzhao-intel Would you mind taking a look at this issue? Thanks.

pengzhao-intel commented 5 years ago

Thanks, @awsbillz @yuxihu. It sounds very useful and I will get back to you soon on this request :)

TaoLv commented 5 years ago

@awsbillz Could you provide a reproducer for the crash?

awsbillz commented 5 years ago

A minimal reproducible example is below:

"""
A combination of 1) 8 channels on Conv2D and 2) a hybridized net seems to cause the crash
"""

import mxnet as mx
from mxnet.gluon import nn
import mxnet.ndarray as F
from mxnet.gluon.block import HybridBlock

class Reshape2D(HybridBlock):
    def __init__(self, factor):
        super(Reshape2D, self).__init__()
        self._factors = (int(factor),) * 2

    def hybrid_forward(self, F, x):
        f1, f2 = self._factors
                                                      # (N, f1*f2*C, H, W)
        x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))  # (N, C, f1*f2, H, W)
        x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0))    # (N, C, f1, f2, H, W)
        x = F.transpose(x, (0, 1, 4, 2, 5, 3))        # (N, C, H, f1, W, f2)
        x = F.reshape(x, (0, 0, -3, -3))              # (N, C, H*f1, W*f2)
        return x

class Net(HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.conv1 = nn.Conv2D(8, kernel_size=5) # <- 1)
            self.reshape2D = Reshape2D(2)

    def hybrid_forward(self, F, x):
        x = self.conv1(x)
        x = self.reshape2D(x)
        return x

net = Net()
net.initialize(mx.init.Xavier(), ctx=mx.cpu())

net.hybridize() # <- 2)

data = mx.nd.random_normal(shape=(1, 3, 600, 600))  # dummy data
output = net(data)
a = output.asnumpy()
print(a)
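
For reference, the special values in those reshape calls follow MXNet's reshape semantics: 0 copies a dimension from the input, -1 infers a dimension from the remaining elements, -4 splits one input dimension into the next two values, and -3 merges two adjacent dimensions. Reshape2D(2) is therefore effectively a depth-to-space (pixel-shuffle) step, and the 6D tensors only exist as intermediates. A small imperative shape check of the same sequence (illustrative only; it uses the shapes from the reproducer above and does not go through the hybridized path that crashes):

import mxnet as mx

x = mx.nd.random_normal(shape=(1, 8, 596, 596))  # conv1 output: (N, f1*f2*C, H, W) with f1 = f2 = 2
y = mx.nd.reshape(x, (0, -4, -1, 4, 0, 0))        # (1, 2, 4, 596, 596)
y = mx.nd.reshape(y, (0, 0, -4, 2, 2, 0, 0))      # (1, 2, 2, 2, 596, 596)  <- the 6D intermediate
y = mx.nd.transpose(y, (0, 1, 4, 2, 5, 3))        # (1, 2, 596, 2, 596, 2)
y = mx.nd.reshape(y, (0, 0, -3, -3))              # (1, 2, 1192, 1192)
print(y.shape)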

We experimented with turning hybridization on and off, and there does appear to be a difference.

Without hybridization, the code runs correctly:

root@6f1ae84d7a5a:/crowdcounting# MKLDNN_VERBOSE=1 python test.py
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_Ohwi8o,num:1,8x3x5x5,0.00317383
mkldnn_verbose,exec,convolution,jit:avx2,forward_inference,fsrc:nchw fwei:Ohwi8o fbia:x fdst:nChw8c,alg:convolution_direct,mb1_g1ic3oc8_ih600oh596kh5sh1dh0ph0_iw600ow596kw5sw1dw0pw0,2.40796
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_Ohwi8o,num:1,8x3x5x5,0.00219727
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x8x596x596,1.37305
[[[[ 0.63157237 -0.13570794 -0.13340439 ...  1.1704854   0.4845633
     1.5047258 ]
   [ 1.2083259   0.37998998 -1.1894836  ... -0.48679283  0.65234876
     0.07821631]
   [-0.5278215  -1.3345666  -0.61814046 ... -0.78595865 -0.6419563
     0.01790139]
   ...
   [ 0.9739233   0.26053143  0.10977446 ... -1.0813308  -0.79446477
     0.59918064]
   [ 0.6313175  -1.1710652   0.14362349 ...  0.14761467  1.2973223
     0.07053347]
   [ 0.595051    0.53777224 -0.7015837  ... -0.9797133   0.46997756
    -0.06855441]]

  [[ 0.08321354 -0.43839744  0.25719148 ... -0.48135263  0.2331308
    -0.20161374]
   [-0.13869795 -1.0949888   0.67054224 ... -0.7056279   0.4550503
    -0.39589623]
   [ 0.45498207 -0.2215124  -0.7991557  ... -1.1719621  -0.6238807
     0.7984329 ]
   ...
   [-0.6333465  -0.48290926  0.6820876  ...  1.4864701   0.6147204
    -0.41900736]
   [ 1.0353158  -0.06823038 -0.11530954 ... -0.06267852 -0.56047356
    -0.70380664]
   [-0.51138985 -0.12400899 -1.7057749  ... -1.1799597  -0.4319848
    -0.0879695 ]]]]

With hybridization, however, we see the stacktrace below:

root@6f1ae84d7a5a:/crowdcounting# MKLDNN_VERBOSE=1 python test.py
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_Ohwi8o,num:1,8x3x5x5,0.00292969
mkldnn_verbose,exec,convolution,jit:avx2,forward_inference,fsrc:nchw fwei:Ohwi8o fbia:x fdst:nChw8c,alg:convolution_direct,mb1_g1ic3oc8_ih600oh596kh5sh1dh0ph0_iw600ow596kw5sw1dw0pw0,2.36499
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x8x596x596,1.36206
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw8c,num:1,1x8x596x596,0.417969
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x8x596x596,1.21411
mkldnn_verbose,exec,reorder,simple:any,undef,in:f32_goihw out:f32_goihw,num:1,1x2x4x596x596,1.04883
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_Ohwi8o,num:1,8x3x5x5,0.000976562
Traceback (most recent call last):
  File "test.py", line 45, in <module>
    a = output.asnumpy()
  File "/usr/local/lib/python2.7/dist-packages/mxnet/ndarray/ndarray.py", line 1980, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/usr/local/lib/python2.7/dist-packages/mxnet/base.py", line 252, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [18:53:06] src/ndarray/ndarray.cc:481: MKLDNN doesn't support 6 dimensions

Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x21d6d4) [0x7f00719ec6d4]
[bt] (1) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x21dab1) [0x7f00719ecab1]
[bt] (2) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(mxnet::NDArray::Chunk::SetMKLMem(nnvm::TShape const&, int)+0xe7) [0x7f007450a317]
[bt] (3) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(mxnet::NDArray::GetMKLDNNData() const+0x5e) [0x7f007451482e]
[bt] (4) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(void mxnet::op::CastStorageComputeImpl<mshadow::cpu>(mxnet::OpContext const&, mxnet::NDArray const&, mxnet::NDArray const&)+0x267) [0x7f007239ebd7]
[bt] (5) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(mxnet::common::CastNonDefaultStorage(std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> > const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> > const&, mxnet::OpContext const&, bool)+0x24f) [0x7f007434207f]
[bt] (6) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x372) [0x7f00743d2ee2]
[bt] (7) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x2b5ba6d) [0x7f007432aa6d]
[bt] (8) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x2b5ba57) [0x7f007432aa57]
[bt] (9) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x2b5ba57) [0x7f007432aa57]

pengzhao-intel commented 5 years ago

Thanks @awsbillz, we will take a look :)

TaoLv commented 5 years ago

Thank you for sharing the reproducer, @awsbillz. Is it possible for you to try a recent nightly build of MXNet? I think the issue has been fixed on the master branch. To install the nightly build through pip: pip install mxnet-mkl --pre
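
If upgrading is not an option right away, running without MKL-DNN should also avoid the failing code path. This is untested on this particular model, but the backend can be switched off with the MXNET_MKLDNN_ENABLED environment variable, e.g.:

MXNET_MKLDNN_ENABLED=0 python test.py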

awsbillz commented 5 years ago

The nightly build does appear to have fixed our issue. Could you point us to the commit that fixed it, just for reference?

When will this be in the next release?

TaoLv commented 5 years ago

Should be this one: https://github.com/apache/incubator-mxnet/commit/c4a619c6808df82e65081bf55e815c7a7ec55566

pengzhao-intel commented 5 years ago

The change will go into release 1.5. Closing now; feel free to reopen if there are any other issues.