Open apeforest opened 6 years ago
@mxnet-label-bot [Python, Bug]
@mxnet-label-bot [Operator]
I will work on this issue: https://issues.apache.org/jira/browse/MXNET-865
@apeforest What's the progress on this one ?
@piyushghai I don't have the bandwidth to work on this one now. Please label it [Call for contribution]
There are two scenarios in which the where operator should pass.
First, when condition is a 1D array, the code below (a snippet from control_flow_op.h, in the method WhereOpShape()) checks that its size equals the first dimension of x (tshape[0]), which is correct given that the number of dimensions is just 1. Line: 191
} else if ((*in_attrs)[0].ndim() == 1) {
CHECK_EQ((*in_attrs)[0].Size(), static_cast<size_t>(tshape[0]));
return true;
import mxnet as mx

# Valid case: the 1D condition has one entry per row of x
condition = mx.sym.Variable('condition')
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
where_sym = mx.sym.where(condition, x, y)
where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
               y=mx.nd.array([[8,9],[10,11],[12,13]]),
               condition=mx.nd.array([1,0,1])) # 1D array
# Invalid case: the 1D condition has 4 entries but x has 3 rows
condition = mx.sym.Variable('condition')
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
where_sym = mx.sym.where(condition, x, y)
where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
               y=mx.nd.array([[8,9],[10,11],[12,13]]),
               condition=mx.nd.array([1,0,1,1])) # Incorrect 1D array
The 1D array condition works as expected in symbolic mode: it throws an error when the size of condition does not match the first dimension of x.
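To make the 1D rule concrete, here is a minimal pure-Python sketch of that check. It is only a model of the CHECK_EQ line quoted above, not MXNet code; the function name check_1d_condition and the use of ValueError are my own choices for illustration.

```python
def check_1d_condition(cond_shape, x_shape):
    """Model of the 1D branch: condition size must equal x's first dimension.

    check_1d_condition is a hypothetical helper, not part of MXNet.
    """
    if len(cond_shape) != 1:
        raise ValueError("this sketch only covers a 1D condition")
    if cond_shape[0] != x_shape[0]:
        # Stands in for the CHECK_EQ failure in the C++ snippet
        raise ValueError(
            "condition size %d does not match x.shape[0] %d"
            % (cond_shape[0], x_shape[0])
        )
    return True


# The two cases from the examples above
print(check_1d_condition((3,), (3, 2)))
try:
    check_1d_condition((4,), (3, 2))
except ValueError as err:
    print("rejected:", err)
```

The point is that a failed check in this branch is loud, so the mismatch cannot slip through.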
Second, condition should have the same shape as input x. The code below calls shape_assign between condition and x, which internally checks dimension compatibility. It returns false (the expected behaviour when there is a dimension mismatch), but the test does not throw any error, which is strange. Line: 187
if ((*in_attrs)[0].ndim() == tshape.ndim()) {
if (!shape_assign(&tshape, (*in_attrs)[0])) return false; // Returns false on a dimension mismatch between condition and x, as expected, yet the script raises no error and passes.
SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape);
return true;
In reality, it should throw an error, something like:
mxnet.base.MXNetError: Error in operator where0:
Shape inconsistent, Provided = [1,2], inferred shape=[3,2]
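For comparison, a rough pure-Python model of the same-ndim branch. This is a simplified sketch, not the real implementation: shape_assign here only mimics the compatibility check (treating a dimension of 0 as unknown is an assumption on my part), and where_shape_strict is a hypothetical variant showing the kind of error I would expect.

```python
def shape_assign(target, source):
    """Simplified model: True when the two shapes are compatible.

    Assumption of this sketch: a dimension of 0 means "unknown".
    """
    if len(target) != len(source):
        return False
    return all(t == s or t == 0 or s == 0 for t, s in zip(target, source))


def where_shape_same_ndim(cond_shape, x_shape):
    """Model of the quoted branch: returns False on mismatch, raises nothing."""
    if len(cond_shape) == len(x_shape):
        return shape_assign(x_shape, cond_shape)
    return False


def where_shape_strict(cond_shape, x_shape):
    """Hypothetical stricter variant that turns the False into an error."""
    if not where_shape_same_ndim(cond_shape, x_shape):
        raise ValueError(
            "Shape inconsistent, Provided = %s, inferred shape=%s"
            % (list(cond_shape), list(x_shape))
        )
    return True


# condition of shape (1, 2) against x of shape (3, 2)
print(where_shape_same_ndim((1, 2), (3, 2)))
try:
    where_shape_strict((1, 2), (3, 2))
except ValueError as err:
    print("rejected:", err)
```

In this model the failure is only a boolean, so whether the user ever sees it depends entirely on what the caller does with that value.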
I will look into it further in detail. Please feel free to suggest if I am going in the wrong direction. Thanks!
Description
When the inputs to an operator are invalid, the operator's InferShape returns false. In symbolic mode this return value is not caught and handled properly, whereas imperative mode raises an error.
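The difference between the two modes can be sketched in plain Python. None of this is MXNet code: infer_shape, run_lenient, run_strict and ShapeInferenceError are hypothetical names that only illustrate a caller dropping a boolean result versus one that checks it.

```python
class ShapeInferenceError(Exception):
    """Hypothetical error type used only in this sketch."""


def infer_shape(shapes):
    """Stand-in for an operator's InferShape: True only if all shapes match."""
    return all(s == shapes[0] for s in shapes)


def run_lenient(shapes):
    """Models the reported symbolic behaviour: the return value is dropped."""
    infer_shape(shapes)  # result ignored, so invalid input goes unnoticed
    return "executed"


def run_strict(shapes):
    """Models the reported imperative behaviour: a failed inference raises."""
    if not infer_shape(shapes):
        raise ShapeInferenceError("Shape inconsistent: %s" % (shapes,))
    return "executed"


bad_shapes = [(1, 2), (3, 2), (3, 2)]  # condition, x, y
print(run_lenient(bad_shapes))
try:
    run_strict(bad_shapes)
except ShapeInferenceError as err:
    print("rejected:", err)
```

If the symbolic path behaves like run_lenient, the fix would presumably be to check the return value at the call site rather than to change the operator itself.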
Environment info (Required)
Package used (Python/R/Scala/Julia): Python
Minimum reproducible example