whamza15 opened 5 years ago
Hey, this is the MXNet Label Bot. Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it. Here are my recommended labels: Feature
Looks like a possible bug to me. I'm labelling it so that the MXNet community can help resolve it.
@mxnet-label-bot Add [Bug, Gluon]
@whamza15 This is not an issue of not using all variables in hybrid_forward, as the following test works:
import mxnet.gluon as gl
import mxnet as mx

class EmbeddingBlock(gl.HybridBlock):
    def __init__(self, num_toks, dim, **kwargs):
        super(EmbeddingBlock, self).__init__(**kwargs)
        self.emb = gl.nn.Embedding(num_toks, dim)

    def hybrid_forward(self, F, x, valid_length):
        # NOTE valid_length is not used
        return self.emb(x)

net = EmbeddingBlock(10, 100)
net.initialize()
net.hybridize()
x1 = mx.nd.array(range(8)).reshape(2, -1)
vl1 = mx.nd.array([3, 2])
x2 = mx.nd.array(range(8)).reshape(2, -1)
vl2 = mx.nd.array([3, 2])
net(x1, vl1)
print(net.collect_params())
EDIT: The above test works because deferred initialization is not used for embedding layers. For layers that use deferred initialization, like nn.Dense, the issue exists, as can be verified with the following:
class Net(gl.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.dense = gl.nn.Dense(3, flatten=False)

    def hybrid_forward(self, F, x, v1):
        return self.dense(x)

net = Net()
net.initialize()
net.hybridize()
x = mx.nd.array(range(8)).reshape(2, -1)
v1 = mx.nd.array([3, 2])
net(x, v1)
Error Message:
/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py:540: UserWarning: The 1-th input to HybridBlock is not used by any computation. Is this intended?
out = self.forward(*args)
infer_shape error. Arguments:
data0: (2, 4)
data1: (2,)
Traceback (most recent call last):
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 803, in _call_cached_op
for is_arg, i in self._cached_op_args]
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 803, in <listcomp>
for is_arg, i in self._cached_op_args]
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/parameter.py", line 494, in data
return self._check_and_get(self._data, ctx)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/parameter.py", line 208, in _check_and_get
"num_features, etc., for network layers."%(self.name))
mxnet.gluon.parameter.DeferredInitializationError: Parameter 'dense0_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 789, in _deferred_infer_shape
self.infer_shape(*args)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 862, in infer_shape
self._infer_attrs('infer_shape', 'shape', *args)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 851, in _infer_attrs
**{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
File "/anaconda3/lib/python3.7/site-packages/mxnet/symbol/symbol.py", line 996, in infer_shape
res = self._infer_shape_impl(False, *args, **kwargs)
File "/anaconda3/lib/python3.7/site-packages/mxnet/symbol/symbol.py", line 1126, in _infer_shape_impl
ctypes.byref(complete)))
File "/anaconda3/lib/python3.7/site-packages/mxnet/base.py", line 252, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [14:53:40] src/c_api/c_api_symbolic.cc:494: InferShapeKeyword argument name data1 not found.
Candidate arguments:
[0]data0
[1]dense0_weight
[2]dense0_bias
Stack trace returned 5 entries:
[bt] (0) 0 libmxnet.so 0x000000011164e390 std::__1::__tree<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::__map_value_compare<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::less<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, true>, std::__1::allocator<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*> > >::destroy(std::__1::__tree_node<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, void*>*) + 2736
[bt] (1) 1 libmxnet.so 0x000000011164e13f std::__1::__tree<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::__map_value_compare<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::less<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, true>, std::__1::allocator<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*> > >::destroy(std::__1::__tree_node<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, void*>*) + 2143
[bt] (2) 2 libmxnet.so 0x0000000112c4a85e MXSymbolInferShape + 9582
[bt] (3) 3 libmxnet.so 0x0000000112c48b82 MXSymbolInferShape + 2194
[bt] (4) 4 libffi.6.dylib 0x000000010a0b1884 ffi_call_unix64 + 76
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test_gl1.py", line 28, in <module>
net(x, v1)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 540, in __call__
out = self.forward(*args)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 907, in forward
return self._call_cached_op(x, *args)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 805, in _call_cached_op
self._deferred_infer_shape(*args)
File "/anaconda3/lib/python3.7/site-packages/mxnet/gluon/block.py", line 793, in _deferred_infer_shape
raise ValueError(error_msg)
ValueError: Deferred initialization failed because shape cannot be inferred. [14:53:40] src/c_api/c_api_symbolic.cc:494: InferShapeKeyword argument name data1 not found.
Candidate arguments:
[0]data0
[1]dense0_weight
[2]dense0_bias
Stack trace returned 5 entries:
[bt] (0) 0 libmxnet.so 0x000000011164e390 std::__1::__tree<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::__map_value_compare<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::less<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, true>, std::__1::allocator<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*> > >::destroy(std::__1::__tree_node<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, void*>*) + 2736
[bt] (1) 1 libmxnet.so 0x000000011164e13f std::__1::__tree<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::__map_value_compare<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, std::__1::less<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, true>, std::__1::allocator<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*> > >::destroy(std::__1::__tree_node<std::__1::__value_type<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, mxnet::NDArrayFunctionReg*>, void*>*) + 2143
[bt] (2) 2 libmxnet.so 0x0000000112c4a85e MXSymbolInferShape + 9582
[bt] (3) 3 libmxnet.so 0x0000000112c48b82 MXSymbolInferShape + 2194
[bt] (4) 4 libffi.6.dylib 0x000000010a0b1884 ffi_call_unix64 + 76
I am trying to figure out if this is actually a bug and if there is a possible workaround for this use case.
@sandeep-krishnamurthy @safrooze Could you please have a look?
Possibly related to #13967
It seems like this is expected behavior. @eric-haibin-lin, could you have a look and confirm?
@whamza15 Since the error pops up due to deferred initialization, you can avoid it by specifying the input shape when creating the layers. Here is the full example:
import mxnet.gluon as gl
import mxnet as mx

class EmbeddingBlock(gl.HybridBlock):
    def __init__(self, num_toks, dim, **kwargs):
        super(EmbeddingBlock, self).__init__(**kwargs)
        self.emb = gl.nn.Embedding(num_toks, dim)

    def hybrid_forward(self, F, x, valid_length):
        # NOTE valid_length is not used
        return self.emb(x)

class Net(gl.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.dense = gl.nn.Dense(3, in_units=160, flatten=False)
        self.e1 = EmbeddingBlock(10, 100)
        self.e2 = EmbeddingBlock(20, 60)

    def hybrid_forward(self, F, x1, vl1, x2, vl2):
        o = F.concat(self.e1(x1, vl1), self.e2(x2, vl2), dim=-1)
        return self.dense(o)

net = Net()
net.initialize()
net.hybridize()
x1 = mx.nd.array(range(8)).reshape(2, -1)
vl1 = mx.nd.array([3, 2])
x2 = mx.nd.array(range(8)).reshape(2, -1)
vl2 = mx.nd.array([3, 2])
net(x1, vl1, x2, vl2)
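Note that in_units=160 matches the width of the concatenated embeddings (100 + 60), so the weight shape of the dense layer is known up front and deferred initialization is never triggered.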
@mxnet-label-bot add [pending requester info]
@whamza15 Did these suggestions help you?
@whamza15 Can you please close the issue if it has been resolved for you?
Please feel free to re-open if closed in error.
Sorry, I did not get a chance to follow up on this. I can try what you described @abhinavs95. However, not using deferred initialization is going to be a bit of a setback in our toolkit, which relies so much on it. Is there a possibility this can be solved while still relying on deferred initialization?
I just want to add that if I use valid_length in the EmbeddingBlock, it works fine even with deferred initialization.
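For illustration, here is a minimal sketch (mine, not from the thread) of one way to "use" valid_length, in this case by masking the padded positions with SequenceMask; the MaskedEmbeddingBlock name is hypothetical:

import mxnet.gluon as gl

class MaskedEmbeddingBlock(gl.HybridBlock):
    def __init__(self, num_toks, dim, **kwargs):
        super(MaskedEmbeddingBlock, self).__init__(**kwargs)
        self.emb = gl.nn.Embedding(num_toks, dim)

    def hybrid_forward(self, F, x, valid_length):
        emb = self.emb(x)  # (batch, seq_len, dim)
        # SequenceMask expects the sequence axis first, hence the swaps.
        masked = F.SequenceMask(emb.swapaxes(0, 1),
                                sequence_length=valid_length,
                                use_sequence_length=True)
        return masked.swapaxes(0, 1)

Because valid_length now appears in the computation graph, data1 is a known argument during infer_shape, so deferred initialization of a downstream Dense layer can succeed.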
@whamza15 does it work if you pass [] as the value for valid_length?
@eric-haibin-lin I am not sure I understand the question. valid_length always has a value. It is just that this block does not use it. The reason we have this setup is that our toolkit allows people to configure blocks (as complex as they want) without having to change the input. Some blocks may choose to consume valid_length (like complex encoders) while others may choose not to (like a simple embedding block).
We have a temporary workaround in https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/model/transformer.py#L420-L501 but this bug should definitely be fixed in MXNet.
There is a similar problem when there are unused parameters. For example, you can have a model like this:
class Test(mx.gluon.nn.HybridBlock):
    def __init__(self, mode, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mode = mode
        with self.name_scope():
            self.d1 = mx.gluon.nn.Dense(2)
            self.d2 = mx.gluon.nn.Dense(3)

    def hybrid_forward(self, F, x, *args, **kwargs):
        o1 = self.d1(x)
        o2 = self.d2(x)
        if self.mode:
            return o1  # output path o2 is not used
        else:
            return o1, o2
Currently, this model will not hybridize successfully when mode == True, because the weights in the o2 path are "unused".
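For completeness, a minimal repro sketch of the failure described above (my own, assuming a 4-feature input; per the comment, the mode=False variant runs fine):

net = Test(mode=True)
net.initialize()
net.hybridize()
# Fails: d2's parameters exist on the block but never enter the cached graph.
net(mx.nd.ones((1, 4)))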
Having unused parameters is useful since you might want your pretrain/finetune/evaluation networks to behave differently, but remain compatible for .save_parameters and .load_parameters without allow_missing and ignore_extra.
I think this issue could be fixed without changing the inner workings too much by adding an F.nodiscard(o2) operator. It would be a no-op in nd mode and would somehow mark the output as a required computation during sym mode. Not sure how feasible something like that is.
My current workaround is something like

return F.broadcast_add(o1, F.sum(0.0 * o2))  # output path o2 is not used

which is both really ugly and potentially inefficient, since it forces the unneeded computation.
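For context, this is my reading of how that one-liner sits inside hybrid_forward of the Test block above (a sketch, not the commenter's verbatim code):

# Inside the Test block:
def hybrid_forward(self, F, x, *args, **kwargs):
    o1 = self.d1(x)
    o2 = self.d2(x)
    if self.mode:
        # 0.0 * o2 zeroes the branch, F.sum collapses it to a scalar, and
        # broadcast_add leaves o1 unchanged while tying o2 into the graph.
        return F.broadcast_add(o1, F.sum(0.0 * o2))
    else:
        return o1, o2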
If the F.nodiscard option is too hard to implement, something like

o1 = F.depends_on(o1, o2)

could also work. It would basically be the same as F.broadcast_add(o1, F.sum(0.0 * o2)) but without any computations.
cc @leezu
Any progress on this?
@whamza15 this will be taken into account in the MXNet 2.0 roadmap item 4.3, Gluon block enhancement, that @leezu is driving.
Description
Not using variables in hybrid_forward() causes deferred initialization to fail. There is no requirement that one should use ALL passed inputs. I am not sure why it fails to infer the input shape for the dense layer. It works fine without hybridize, of course. The reason we are passing input data to blocks without using them is that some subclasses use them, and we would like to unify the interface so that calling blocks do not have to be aware of what type of blocks they are calling. We cannot use __call__() or forward() since these blocks will be hybridized and served from C++.
What have you tried to solve it?
The only solution that works is to use the unused variables in the graph in a redundant way.
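Concretely, applied to the failing Net example above, the redundant use can look like this (a sketch, assuming the same shapes as in the repro, not an endorsed fix):

class Net(gl.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.dense = gl.nn.Dense(3, flatten=False)

    def hybrid_forward(self, F, x, v1):
        # Touch v1 so that data1 appears in the symbolic graph; the added
        # term is always zero, so the output value is unchanged.
        return F.broadcast_add(self.dense(x), F.sum(0.0 * v1))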