dmlc / gluon-nlp

NLP made easy
https://nlp.gluon.ai/
Apache License 2.0

Use official MXNet batchify to implement the batchify functions #1440

Open sxjscience opened 3 years ago

sxjscience commented 3 years ago

Description

We should use the official MXNet batchify functions to implement our own batchify functions. However, since we'd like to support other frameworks later, we should still keep our own batchify.py and change it to call the MXNet implementations.

MXNet batchify: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/data/batchify.py
GluonNLP batchify: https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/data/batchify.py

shenfei commented 3 years ago

GluonNLP’s data.batchify.Stack behaves differently from mxnet.gluon.data.batchify.Stack.

Am I using it incorrectly, or is this a bug in the MXNet implementation?

[ins] In [1]: a = [1, 2, 3, 4]

[ins] In [2]: b = [4, 5, 6, 8]

[ins] In [3]: c = [8, 9, 1, 2]

[ins] In [4]: import gluonnlp.data.batchify as bf

[ins] In [5]: import mxnet.gluon.data.batchify as mxbf

[ins] In [6]: bf.Stack()([a, b, c])
[14:36:53] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
Out[6]:

[[1 2 3 4]
 [4 5 6 8]
 [8 9 1 2]]
<NDArray 3x4 @cpu(0)>

[ins] In [7]: mxbf.Stack()([a, b, c])
Out[7]:
[
 [1 4 8]
 <NDArray 3 @cpu(0)>,

 [2 5 9]
 <NDArray 3 @cpu(0)>,

 [3 6 1]
 <NDArray 3 @cpu(0)>,

 [4 8 2]
 <NDArray 3 @cpu(0)>]

BTW, the result of MXNet’s batchify.Stack contradicts its own docstring.
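The two outputs above differ in the way that row-wise stacking differs from position-wise zipping. The following is a minimal pure-Python sketch of that difference. It does not call MXNet or GluonNLP, and it only assumes (this is not confirmed in the thread) that list samples end up being grouped position by position somewhere inside the MXNet Stack:

```python
# Pure-Python illustration; no MXNet required.
a = [1, 2, 3, 4]
b = [4, 5, 6, 8]
c = [8, 9, 1, 2]

# Row-wise stacking: one row per sample (the layout of Out[6]).
stacked = [list(sample) for sample in (a, b, c)]

# Position-wise zipping: one group per position (the layout of Out[7]).
zipped = [list(group) for group in zip(a, b, c)]

print(stacked)  # [[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]
print(zipped)   # [[1, 4, 8], [2, 5, 9], [3, 6, 1], [4, 8, 2]]
```

The zipped layout has four groups of three values, which is the same grouping that appears in Out[7].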

szha commented 3 years ago

cc @zhreshold

sxjscience commented 3 years ago

This looks like a bug in the MXNet batchify function, and I can reproduce it:

import mxnet as mx
import mxnet.gluon.data.batchify as mxbf
mx.npx.set_np()
a = [1, 2, 3, 4]
b = [4, 5, 6, 8]
c = [8, 9, 1, 2]
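# Plain Python lists: returns a list of four length-3 arrays (transposed).
print(mxbf.Stack()([a, b, c]))
# mx.np arrays: returns the expected 3x4 stacked array.
print(mxbf.Stack()([mx.np.array(a), mx.np.array(b), mx.np.array(c)]))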

Output:

[array([1, 4, 8], dtype=int64), array([2, 5, 9], dtype=int64), array([3, 6, 1], dtype=int64), array([4, 8, 2], dtype=int64)]
[[1. 2. 3. 4.]
 [4. 5. 6. 8.]
 [8. 9. 1. 2.]]

Thus, we should fix this bug in MXNet before switching to the official implementation.
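Until the fix lands, one possible stopgap is suggested by the second print above: convert every sample to an array before handing the batch to the backend Stack. The sketch below is only an illustration of that idea. `ArrayFirstStack`, `fake_to_array`, and `fake_stack` are hypothetical names that do not exist in GluonNLP or MXNet, and the stand-in functions are there only so the example runs without MXNet installed:

```python
from typing import Any, Callable, Sequence


class ArrayFirstStack:
    """Hypothetical wrapper: convert samples to arrays, then delegate stacking."""

    def __init__(self,
                 to_array: Callable[[Any], Any],
                 backend_stack: Callable[[Sequence[Any]], Any]) -> None:
        self._to_array = to_array            # e.g. an array constructor
        self._backend_stack = backend_stack  # e.g. a backend Stack instance

    def __call__(self, samples: Sequence[Any]) -> Any:
        # Convert first, so the backend never sees raw Python lists.
        arrays = [self._to_array(sample) for sample in samples]
        return self._backend_stack(arrays)


# Stand-ins (assumptions for illustration only, not real backend calls).
def fake_to_array(sample):
    return tuple(sample)


def fake_stack(arrays):
    return [list(arr) for arr in arrays]


stack = ArrayFirstStack(fake_to_array, fake_stack)
batch = stack([[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]])
print(batch)  # [[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]
```

With MXNet, the same wrapper would presumably be built as ArrayFirstStack(mx.np.array, mxbf.Stack()) after calling mx.npx.set_np(), mirroring the second print in the snippet above. Note that in that output the integers came back as floats, so the dtype may need to be set explicitly.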