apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

SGLD Optimizer in python cannot work due to wrong argument position #9605

Closed y1xiaoc closed 6 years ago

y1xiaoc commented 6 years ago

Note: Providing complete information in the most concise form is the best way to get help. This issue template serves as the checklist for essential information to most of the technical issues and bug reports. For non-technical issues and feature requests, feel free to present the information in what you believe is the best form.

For Q & A and discussion, please start a discussion thread at https://discuss.mxnet.io

Description

When calling the SGLD optimizer's update method, an error occurs because the arguments in the call to random.normal are passed in the wrong positions.
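
For context, a minimal sketch of the argument mismatch (the parameter names below follow the random.normal signature in MXNet 1.0.0; treat the exact keyword names as an assumption):

```python
import math
import mxnet as mx

lr = 0.01
weight = mx.nd.ones((2, 2))

# SGLD.update builds its Gaussian noise with purely positional arguments, so
# weight.context gets bound to the dtype parameter of
# normal(loc, scale, shape, dtype, ctx, out, ...):
#   mx.nd.random.normal(0, math.sqrt(lr), weight.shape, weight.context)
#   -> TypeError: data type not understood

# Passing the context by keyword binds every argument as intended:
noise = mx.nd.random.normal(0, math.sqrt(lr), weight.shape, ctx=weight.context)
```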

Environment info (Required)

----------Python Info----------
Version : 3.6.2
Compiler : MSC v.1900 64 bit (AMD64)
Build : ('default', 'Jul 20 2017 12:30:02')
Arch : ('64bit', 'WindowsPE')
------------Pip Info-----------
Version : 9.0.1
Directory : C:\Users\Xiao\Anaconda3\lib\site-packages\pip
----------MXNet Info-----------
C:\Users\Xiao\Anaconda3\lib\site-packages\h5py\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from float to np.floating is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type.
  from ._conv import register_converters as _register_converters
Version : 1.0.0
Directory : C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform : Windows-10-10.0.17074-SP0
system : Windows
node : Xiao-pc
release : 10
version : 10.0.17074
----------Hardware Info----------
machine : AMD64
processor : Intel64 Family 6 Model 60 Stepping 3, GenuineIntel
Error: Description = An unexpected error occurred.
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0060 sec, LOAD: 1.4936 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0937 sec, LOAD: 0.1105 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.0120 sec, LOAD: 0.3488 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0210 sec, LOAD: 0.4903 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0120 sec, LOAD: 0.0540 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0100 sec, LOAD: 0.0442 sec.

Package used (Python/R/Scala/Julia):

Python

Error Message:

Traceback (most recent call last):
  File "d:/Research/deepLearning/mxnet/mx_model.py", line 180, in <module>
    trainer.train(model)
  File "d:/Research/deepLearning/mxnet/mx_model.py", line 142, in train
    trainer.step(1)
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\gluon\trainer.py", line 199, in step
    upd(i, grad, arr)
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\optimizer.py", line 1163, in __call__
    self.optimizer.update_multi_precision(index, weight, grad, self.states[index])
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\optimizer.py", line 275, in update_multi_precision
    self.update(index, weight, grad, state)
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\optimizer.py", line 653, in update
    weight.shape, weight.context)
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\ndarray\random.py", line 152, in normal
    [loc, scale], shape, dtype, ctx, out, kwargs)
  File "C:\Users\Xiao\Anaconda3\lib\site-packages\mxnet\ndarray\random.py", line 47, in _random_helper
    return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs)
  File "<string>", line 55, in _random_normal
TypeError: data type not understood

Steps to reproduce

Simply call the SGLD optimizer's update method; a minimal reproduction sketch is below.
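
A minimal reproduction sketch going straight through the optimizer API (the shapes and learning rate are arbitrary; MXNet 1.0.0 assumed):

```python
import mxnet as mx

opt = mx.optimizer.SGLD(learning_rate=0.01)
weight = mx.nd.ones((4, 4))
grad = mx.nd.ones((4, 4))

state = opt.create_state(0, weight)  # SGLD keeps no optimizer state
opt.update(0, weight, grad, state)   # raises TypeError: data type not understood
```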

What have you tried to solve it?

Change the line weight.shape, weight.context) (https://github.com/apache/incubator-mxnet/blob/9ae63d16b600b52d7ec6c83cdffe427a27583fee/python/mxnet/optimizer.py#L767)

to weight.shape, ctx=weight.context), since with the positional call weight.context ends up in the dtype argument slot.

sxjscience commented 6 years ago

Thanks for reporting this! Would you like to submit a pull request to fix the problem? We can change the code to normal(0, math.sqrt(lr), shape=weight.shape, dtype=weight.dtype, ctx=weight.context)
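
For reference, the fixed SGLD.update would look roughly like the sketch below (based on the 1.0.0 source, with only the normal() call changed to keyword arguments):

```python
# Sketch of the change inside python/mxnet/optimizer.py -- NDArray, clip,
# normal and math are already imported at module level in that file.

class SGLD(Optimizer):
    ...
    def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
        # Pass shape/dtype/ctx by keyword so weight.context no longer lands in
        # the dtype position of random.normal:
        weight[:] += - lr / 2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
                                                              shape=weight.shape,
                                                              dtype=weight.dtype,
                                                              ctx=weight.context)
```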