apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Computation fails with oneDNN error #21199

Open matteosal opened 1 year ago

matteosal commented 1 year ago
import mxnet as mx

json = "{\"nodes\":[{\"op\":\"null\",\"name\":\".Inputs.Input\",\"inputs\":[]\
},{\"op\":\"Reshape\",\"name\":\".Nodes.1$0\",\"attrs\":{\"shape\":\"(\
-3, -2)\"},\"inputs\":[[0,0,0]]},{\"op\":\"null\",\"name\":\".Nodes.1.\
Parameters.Net.Arrays.Weights\",\"inputs\":[]},{\"op\":\"null\",\"\
name\":\".Nodes.1.Parameters.Net.Arrays.Biases\",\"inputs\":[]},{\"op\
\":\"Convolution\",\"name\":\".Nodes.1.Parameters.Net\",\"attrs\":{\"\
cudnn_off\":\"0\",\"dilate\":\"(1, 1)\",\"kernel\":\"(1, \
1)\",\"layout\":\"None\",\"no_bias\":\"False\",\"num_filter\":\"200\",\
\"num_group\":\"1\",\"pad\":\"(0, 0)\",\"stride\":\"(1, \
1)\"},\"inputs\":[[1,0,0],[2,0,0],[3,0,0]]},{\"op\":\"reshape_like\",\
\"name\":\".Nodes.1$1\",\"attrs\":{\"lhs_begin\":\"0\",\"lhs_end\":\"\
1\",\"rhs_begin\":\"0\",\"rhs_end\":\"2\"},\"inputs\":[[4,0,0],[0,0,0]\
]},{\"op\":\"Reshape\",\"name\":\".Nodes.2$0\",\"attrs\":{\"shape\":\"\
(0, 2, 200)\"},\"inputs\":[[5,0,0]]},{\"op\":\"_copy\",\"name\":\".\
Outputs.Output\",\"inputs\":[[6,0,0]]}],\"arg_nodes\":[0,2,3],\"heads\
\":[[7,0,0]]}"

sym = mx.symbol.fromjson(json)
op = mx.ndarray.CachedOp(sym)

args = [
    # .Inputs.Input
    mx.np.random.uniform(size=[1, 2, 200, 1, 1], ctx=mx.cpu()),
    # .Nodes.1.Parameters.Net.Arrays.Weights
    mx.np.random.uniform(size=[200, 200, 1, 1], ctx=mx.cpu()),
    # .Nodes.1.Parameters.Net.Arrays.Biases
    mx.np.random.uniform(size=[200], ctx=mx.cpu())  
]

output = op(*args)
print(output)
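The escaped JSON string in the script above is hard to read and edit. Below is a minimal sketch that builds the same graph from a readable Python dict using only the standard json module. build_graph_json and its num_filter parameter are hypothetical helpers added for illustration. The key order and compact separators are chosen to mirror the original string, but whether mx.symbol.fromjson accepts the output unchanged has not been verified here.

```python
import json


def build_graph_json(num_filter=200):
    """Rebuild the reproducer's symbol JSON from a dict (hypothetical helper)."""

    def null(name):
        # Variable (argument) node with no inputs.
        return {"op": "null", "name": name, "inputs": []}

    nodes = [
        null(".Inputs.Input"),                                   # node 0
        {"op": "Reshape", "name": ".Nodes.1$0",                  # node 1: merge axes 0 and 1
         "attrs": {"shape": "(-3, -2)"}, "inputs": [[0, 0, 0]]},
        null(".Nodes.1.Parameters.Net.Arrays.Weights"),          # node 2
        null(".Nodes.1.Parameters.Net.Arrays.Biases"),           # node 3
        {"op": "Convolution", "name": ".Nodes.1.Parameters.Net",  # node 4: 1x1 convolution
         "attrs": {"cudnn_off": "0", "dilate": "(1, 1)", "kernel": "(1, 1)",
                   "layout": "None", "no_bias": "False",
                   "num_filter": str(num_filter), "num_group": "1",
                   "pad": "(0, 0)", "stride": "(1, 1)"},
         "inputs": [[1, 0, 0], [2, 0, 0], [3, 0, 0]]},
        {"op": "reshape_like", "name": ".Nodes.1$1",             # node 5: restore leading axes
         "attrs": {"lhs_begin": "0", "lhs_end": "1", "rhs_begin": "0", "rhs_end": "2"},
         "inputs": [[4, 0, 0], [0, 0, 0]]},
        {"op": "Reshape", "name": ".Nodes.2$0",                  # node 6: drop trailing 1x1 axes
         "attrs": {"shape": "(0, 2, %d)" % num_filter}, "inputs": [[5, 0, 0]]},
        {"op": "_copy", "name": ".Outputs.Output", "inputs": [[6, 0, 0]]},  # node 7
    ]
    graph = {"nodes": nodes, "arg_nodes": [0, 2, 3], "heads": [[7, 0, 0]]}
    # Compact separators mirror the original string; JSON parsers ignore the whitespace anyway.
    return json.dumps(graph, separators=(",", ":"))
```

With this helper, a call such as sym = mx.symbol.fromjson(build_graph_json(199)) would try 199 output filters. In that case the weights and biases in args would also need to be resized to [199, 200, 1, 1] and [199].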

The above script fails with the following error:

mxnet.base.MXNetError: MXNetError: could not create a primitive descriptor for a reorder primitive

The error appears to originate in dnnl.hpp, the oneDNN C++ API header. The symbol contains a single Convolution operator (1x1 kernel) with 200 input channels and 200 output channels (num_filter = 200). Lowering either count to 199 does not trigger the error.
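Tracing tensor shapes through the graph shows what the convolution actually receives. Below is a minimal sketch in plain Python. It does not call MXNet or oneDNN and relies on two assumptions. The first is the documented semantics of MXNet's legacy Reshape special codes: 0 copies one dimension, -2 copies all remaining dimensions, and -3 merges two consecutive dimensions. Other codes are deliberately not handled. The second is the semantics of reshape_like, whose output shape is lhs.shape[:lhs_begin] + rhs.shape[rhs_begin:rhs_end] + lhs.shape[lhs_end:]. mx_reshape, reshape_like and trace_shapes are hypothetical helpers.

```python
import math


def mx_reshape(shape, spec):
    """Hypothetical re-implementation of a subset of legacy Reshape codes."""
    out, i = [], 0
    for code in spec:
        if code == 0:            # copy one input dimension
            out.append(shape[i])
            i += 1
        elif code == -2:         # copy all remaining input dimensions
            out.extend(shape[i:])
            i = len(shape)
        elif code == -3:         # merge two consecutive input dimensions
            out.append(shape[i] * shape[i + 1])
            i += 2
        elif code > 0:           # literal size; advances the input index
            out.append(code)
            i += 1
        else:
            raise NotImplementedError(f"Reshape code {code} not handled in this sketch")
    out = tuple(out)
    # A reshape must preserve the total number of elements.
    assert math.prod(out) == math.prod(shape), (shape, spec, out)
    return out


def reshape_like(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end):
    """Output shape of reshape_like for the given axis ranges."""
    return tuple(lhs[:lhs_begin]) + tuple(rhs[rhs_begin:rhs_end]) + tuple(lhs[lhs_end:])


def trace_shapes(c_in=200, c_out=200):
    """Shapes at each node of the reproducer graph (hypothetical helper)."""
    x = (1, 2, c_in, 1, 1)                     # .Inputs.Input
    conv_in = mx_reshape(x, (-3, -2))          # .Nodes.1$0 -> (2, c_in, 1, 1)
    n, c, h, w = conv_in
    assert c == c_in                           # weights are (c_out, c_in, 1, 1)
    conv_out = (n, c_out, h, w)                # 1x1 kernel, stride 1, no padding
    r2 = reshape_like(conv_out, x, 0, 1, 0, 2)  # .Nodes.1$1 -> (1, 2, c_out, 1, 1)
    out = mx_reshape(r2, (0, 2, c_out))        # .Nodes.2$0 -> (1, 2, c_out)
    return {"input": x, "conv_in": conv_in, "conv_out": conv_out,
            "reshape_like": r2, "output": out}


for key, shape in trace_shapes().items():
    print(f"{key:>12}: {shape}")
```

Under these assumptions, the convolution sees a batch of 2 with 200 channels at 1x1 spatial size and produces the same shape. trace_shapes(199, 200) and trace_shapes(200, 199) give the shapes for the two variants the report says do not fail.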