apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Tracking mxnet.numpy operator issues for 1.6.0 release #16559

Closed: reminisce closed this issue 4 years ago

reminisce commented 4 years ago

Issues found while running the D2L numpy2 branch

  1. chapter_preliminaries/probability.md

     TypeError: no implementation found for 'numpy.broadcast_arrays' on types that implement __array_function__: [<class 'numpy.ndarray'>, <class 'mxnet.numpy.ndarray'>]
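This TypeError comes from NumPy's NEP 18 (`__array_function__`) dispatch: when every argument type returns `NotImplemented` for an operator, NumPy raises "no implementation found". A minimal stdlib-only sketch of that mechanism (the classes and `dispatch` helper below are illustrative stand-ins, not NumPy's or MXNet's actual internals):

```python
class PlainArray:
    """Stand-in for numpy.ndarray: handles the call only for known types."""
    def __array_function__(self, func, types, args, kwargs):
        if not all(issubclass(t, PlainArray) for t in types):
            return NotImplemented
        return (func.__name__, args)

class MxArray:
    """Stand-in for mxnet.numpy.ndarray with no broadcast_arrays override."""
    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented  # operator not registered for this type

def dispatch(func, *args):
    """Simplified version of NumPy's implement_array_function."""
    types = {type(a) for a in args}
    for a in args:
        result = a.__array_function__(func, types, args, {})
        if result is not NotImplemented:
            return result
    raise TypeError(
        "no implementation found for '%s' on types that implement "
        "__array_function__: %s"
        % (func.__name__, sorted(t.__name__ for t in types)))

def broadcast_arrays(*arrays):  # placeholder for numpy.broadcast_arrays
    pass
```

Mixing the two array types makes every handler return `NotImplemented`, reproducing the error above.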
  2. chapter_preliminaries/probability.md

    MXNetError: [09:16:23] src/operator/numpy/np_true_divide.cc:43: Check failed: lhs_dtype == rhs_dtype (7 vs. 0) : true_divide currently only supports same dtype for dividend and divisor
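A hedged sketch of the call-site workaround for this check: coerce both operands to a common type before dividing. (In MXNet's dtype enum, 7 vs. 0 appears to be bool vs. float32, hence the mismatch; for real mxnet.numpy ndarrays the analogous fix would be an explicit `astype()` on one operand. `float()` here is the scalar stand-in.)

```python
def safe_true_divide(lhs, rhs):
    """Divide after forcing dividend and divisor to the same (float) type."""
    # Coercing up front sidesteps the lhs_dtype == rhs_dtype check.
    return float(lhs) / float(rhs)
```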
  3. chapter_multilayer-perceptrons/dropout.md

    mask = np.random.uniform(0, 1, X.shape) > drop_prob
    return mask * X / (1.0 - drop_prob)

    Fails on the second line; the problem appears to be multiplying a boolean array by a float array.
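For reference, a stdlib-only sketch of what this snippet computes (inverted dropout): zero each element with probability `drop_prob`, scale survivors by `1 / (1 - drop_prob)`. Building the mask as floats (0.0/1.0) rather than booleans sidesteps the bool-times-float multiplication that fails here; the list-based `dropout` below is illustrative, not d2l's implementation.

```python
import random

def dropout(X, drop_prob):
    """Inverted dropout over a flat list of floats."""
    assert 0 <= drop_prob < 1
    # Float mask instead of boolean mask avoids bool * float issues.
    mask = [1.0 if random.uniform(0, 1) > drop_prob else 0.0 for _ in X]
    return [m * x / (1.0 - drop_prob) for m, x in zip(mask, X)]
```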
  4. chapter_deep-learning-computation/parameters.md

    data[:] = np.random.uniform(-10, 10, data.shape)
    data *= np.abs(data) >= 5

    Fails on the second line; again the multiplication between a boolean array and a float array.
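The intent of this snippet is to keep only entries with magnitude at least 5 and zero the rest. A stdlib sketch showing the same workaround as above, converting the comparison result to float before multiplying (illustrative only, not the mxnet fix):

```python
def clip_small(data, threshold=5.0):
    """Zero out entries whose magnitude is below threshold."""
    # float(...) turns the boolean comparison into 0.0/1.0 before multiplying.
    return [x * float(abs(x) >= threshold) for x in data]
```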
  5. chapter_recurrent-neural-networks/seq2seq.md

    MXNetError: [10:46:26] src/operator/numpy/np_true_divide.cc:43: Check failed: lhs_dtype == rhs_dtype (0 vs. 6) : true_divide currently only supports same dtype for dividend and divisor

    <ipython-input-10-23a855fab898> in train_s2s_ch8(model, data_iter, lr, num_epochs, ctx)
         24             metric.add(l.sum(), num_tokens)
         25         if epoch % 10 == 0:
    ---> 26             animator.add(epoch, (metric[0]/metric[1],))
         27     print('loss %.3f, %d tokens/sec on %s ' % (
         28         metric[0]/metric[1], metric[1]/timer.stop(), ctx))
  6. All notebooks with train_s2s_ch8

    MXNetError: [10:46:26] src/operator/numpy/np_true_divide.cc:43: Check failed: lhs_dtype == rhs_dtype (0 vs. 6) : true_divide currently only supports same dtype for dividend and divisor

    <ipython-input-10-23a855fab898> in train_s2s_ch8(model, data_iter, lr, num_epochs, ctx)
         24             metric.add(l.sum(), num_tokens)
         25         if epoch % 10 == 0:
    ---> 26             animator.add(epoch, (metric[0]/metric[1],))
         27     print('loss %.3f, %d tokens/sec on %s ' % (
         28         metric[0]/metric[1], metric[1]/timer.stop(), ctx))
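The failing pattern here: `metric` accumulates an mxnet scalar ndarray (from `l.sum()`) alongside a Python int, so `metric[0] / metric[1]` divides mismatched dtypes. One hedged workaround is to convert to Python floats at accumulation time. The `Accumulator` class below is a stdlib stand-in sketch, not d2l's actual helper:

```python
class Accumulator:
    """Accumulate sums over n variables as plain Python floats."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        # float(...) would also unwrap a size-1 mxnet ndarray here,
        # so later divisions stay in plain Python float arithmetic.
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def __getitem__(self, idx):
        return self.data[idx]
```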
  7. chapter_optimization/optimization-intro.md

    ValueError: mxnet.numpy operator `<function column_stack at 0x7ff910066e18>` has not been registered in the _NUMPY_ARRAY_FUNCTION_LIST. Please make sure you are using NumPy >= 1.17.0 and the operator implementation is compatible with NumPy. Then add the operator name to the list.
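Behind this ValueError, mxnet keeps a whitelist of NumPy operators that it dispatches through `__array_function__`, and an operator absent from that list fails loudly. A stdlib sketch of that registration check (names and list contents below are illustrative, not mxnet's actual internals):

```python
# Hypothetical whitelist; 'column_stack' is deliberately missing.
_NUMPY_ARRAY_FUNCTION_LIST = ['broadcast_to', 'stack']

def check_registered(op_name):
    """Raise, as mxnet does, when an operator was never registered."""
    if op_name not in _NUMPY_ARRAY_FUNCTION_LIST:
        raise ValueError(
            "mxnet.numpy operator `%s` has not been registered in the "
            "_NUMPY_ARRAY_FUNCTION_LIST." % op_name)
    return True
```

The fix on the mxnet side is simply to implement the operator and append its name to the list.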
  8. chapter_optimization/convexity.md

    Same as above
  9. chapter_natural-language-processing/sentiment-analysis-rnn.md

    ---------------------------------------------------------------------------
    MXNetError                                Traceback (most recent call last)
    <ipython-input-7-a7e697bf18e7> in <module>
          2 trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
          3 loss = gluon.loss.SoftmaxCrossEntropyLoss()
    ----> 4 d2l.train_ch12(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)

    ~/d2l-numpy/d2l/d2l.py in train_ch12(net, train_iter, test_iter, loss, trainer, num_epochs, ctx_list, split_f)
       1054             timer.start()
       1055             l, acc = train_batch_ch12(
    -> 1056                 net, features, labels, loss, trainer, ctx_list, split_f)
       1057             metric.add(l, acc, labels.shape[0], labels.size)
       1058             timer.stop()

    ~/d2l-numpy/d2l/d2l.py in train_batch_ch12(net, features, labels, loss, trainer, ctx_list, split_f)
       1037         l.backward()
       1038     trainer.step(features.shape[0])
    -> 1039     train_loss_sum = sum([float(l.sum()) for l in ls])
       1040     train_acc_sum = sum(d2l.accuracy(py, y) for py, y in zip(pys, ys))
       1041     return train_loss_sum, train_acc_sum

    ~/d2l-numpy/d2l/d2l.py in <listcomp>(.0)
       1037         l.backward()
       1038     trainer.step(features.shape[0])
    -> 1039     train_loss_sum = sum([float(l.sum()) for l in ls])
       1040     train_acc_sum = sum(d2l.accuracy(py, y) for py, y in zip(pys, ys))
       1041     return train_loss_sum, train_acc_sum

    ~/mxnet_master/python/mxnet/numpy/multiarray.py in __float__(self)
        791         if num_elements != 1:
        792             raise TypeError('only size-1 arrays can be converted to Python scalars')
    --> 793         return float(self.item())
        794
        795     def __int__(self):

    ~/mxnet_master/python/mxnet/numpy/multiarray.py in item(self, *args)
        830         """
        831         # TODO(junwu): no need to call asnumpy() on the whole array.
    --> 832         return self.asnumpy().item(*args)
        833
        834     @property

    ~/mxnet_master/python/mxnet/ndarray/ndarray.py in asnumpy(self)
       2517             self.handle,
       2518             data.ctypes.data_as(ctypes.c_void_p),
    -> 2519             ctypes.c_size_t(data.size)))
       2520         return data
       2521

    ~/mxnet_master/python/mxnet/base.py in check_call(ret)
        252     """
        253     if ret != 0:
    --> 254         raise MXNetError(py_str(_LIB.MXGetLastError()))
        255

    MXNetError: [11:23:59] src/operator/./rnn-inl.h:1505: Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading: CUDA: invalid resource handle
    Stack trace:
    [bt] (0) /home/ubuntu/mxnet_master/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x43) [0x7fe81ce57ff3]
    [bt] (1) /home/ubuntu/mxnet_master/python/mxnet/../../lib/libmxnet.so(mxnet::op::RNNOp<mshadow::gpu, float, float>::Backward(mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x129a) [0x7fe82357b35a]
    [bt] (2) /home/ubuntu/mxnet_master/python/mxnet/../../lib/libmxnet.so(void mxnet::op::RNNStatefulGradCompute<mshadow::gpu>(mxnet::OpStatePtr const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x146e) [0x7fe8235b508e]

    Note: Caused by #16391 (https://github.com/apache/incubator-mxnet/pull/16391); verified the fix by reverting that PR. Also pinged DickJC123 for a fix.

 10. chapter_recommender-systems/autorec.md

    ---------------------------------------------------------------------------
    MXNetError                                Traceback (most recent call last)
    <ipython-input-4-cc01b0d0d033> in <module>
          2 # Load the MovieLens 100K dataset
          3 df, num_users, num_items = d2l.read_data_ml100k()
    ----> 4 train_data, test_data = d2l.split_data_ml100k(df, num_users, num_items)
          5 _, _, _, train_inter_mat = d2l.load_data_ml100k(train_data, num_users,
          6                                                 num_items)

    ~/d2l-numpy/d2l/d2l.py in split_data_ml100k(data, num_users, num_items, split_mode, test_ratio)
       1432     else:
       1433         mask = [True if x == 1 else False for x in np.random.uniform(
    -> 1434             0, 1, (len(data))) < 1 - test_ratio]
       1435         neg_mask = [not x for x in mask]
       1436         train_data, test_data = data[mask], data[neg_mask]

    ~/d2l-numpy/d2l/d2l.py in <listcomp>(.0)
       1431         test_data = pd.DataFrame(test_data)
       1432     else:
    -> 1433         mask = [True if x == 1 else False for x in np.random.uniform(
       1434             0, 1, (len(data))) < 1 - test_ratio]
       1435         neg_mask = [not x for x in mask]

    ~/mxnet_master/python/mxnet/numpy/multiarray.py in __bool__(self)
        781             return False
        782         elif num_elements == 1:
    --> 783             return bool(self.item())
        784         else:
        785             raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.")

    ~/mxnet_master/python/mxnet/numpy/multiarray.py in item(self, *args)
        830         """
        831         # TODO(junwu): no need to call asnumpy() on the whole array.
    --> 832         return self.asnumpy().item(*args)
        833
        834     @property

    ~/mxnet_master/python/mxnet/ndarray/ndarray.py in asnumpy(self)
       2517             self.handle,
       2518             data.ctypes.data_as(ctypes.c_void_p),
    -> 2519             ctypes.c_size_t(data.size)))
       2520         return data
       2521

    ~/mxnet_master/python/mxnet/base.py in check_call(ret)
        252     """
        253     if ret != 0:
    --> 254         raise MXNetError(py_str(_LIB.MXGetLastError()))
        255

    MXNetError: [11:27:58] src/operator/numpy/../tensor/elemwise_binary_scalar_op.h:264: Unknown type enum 7

    Note: Verified a fix by changing the type switch at src/operator/tensor/elemwise_binary_scalar_op.h:264 to the "WITH_BOOL" version.
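A hedged call-site workaround sketch for `split_data_ml100k`: build the train/test mask from Python bools directly (`random.random() < ratio`), so no element of an mxnet ndarray is ever coerced through `__bool__` / `asnumpy()`, which is where the "Unknown type enum 7" (bool) error surfaces. Stdlib-only, illustrative of the masking logic rather than d2l's code:

```python
import random

def split_mask(n, test_ratio=0.1, seed=42):
    """Return complementary train/test boolean masks of length n."""
    rng = random.Random(seed)
    # Each element is a plain Python bool, never an ndarray scalar.
    mask = [rng.random() < 1 - test_ratio for _ in range(n)]
    neg_mask = [not x for x in mask]
    return mask, neg_mask
```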

Missing Variables and Functions

  1. np.newaxis
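In NumPy, `np.newaxis` is simply an alias for `None`, so exposing it in mxnet.numpy should be a one-line addition; indexing with it inserts a length-1 axis (e.g. `x[:, np.newaxis]` turns shape `(n,)` into `(n, 1)`) once `__getitem__` supports it. Minimal sketch:

```python
# np.newaxis is documented as an alias for None; defining it is trivial.
newaxis = None
```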
ptrendx commented 4 years ago

@reminisce Is there any update to the status of those issues?

reminisce commented 4 years ago

@ptrendx All issues should have been fixed by @haojin2. Thanks.

ptrendx commented 4 years ago

Ok, if those issues are fixed, could you close this then?