apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

How to add BatchNorm to lstm bucketing? #2800

Open dianyancao opened 8 years ago

dianyancao commented 8 years ago

I tried to add BatchNorm to lstm.py in 'example/rnn' by changing the line gates = i2h + h2h in the lstm function to:

i2h_bn = mx.sym.BatchNorm(data=i2h, fix_gamma=False, name="t%d_l%d_i2h_bn" % (seqidx, layeridx))
h2h_bn = mx.sym.BatchNorm(data=h2h, fix_gamma=False, name="t%d_l%d_h2h_bn" % (seqidx, layeridx))
gates = i2h_bn + h2h_bn

Then I ran lstm_bucketing.py and got the following error:

Traceback (most recent call last):
  File "E:/NewCR/mxnet/mxnet/example/rnn/lstm_bucketing.py", line 94, in <module>
    batch_end_callback=batch_end_callback(batch_size, 40))
  File "E:\NewCR\python\mxnet\model.py", line 791, in fit
    sym_gen=self.sym_gen)
  File "E:\NewCR\python\mxnet\model.py", line 223, in _train_multi_device
    executor_manager.load_data_batch(data_batch)
  File "E:\NewCR\python\mxnet\executor_manager.py", line 388, in load_data_batch
    shared_group=self.execgrp)
  File "E:\NewCR\python\mxnet\executor_manager.py", line 235, in __init__
    for i in self.param_idx]
IndexError: list index out of range

How can I make BatchNorm work with bucketing? Thanks in advance for your reply.
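One plausible reading of the traceback is that the per-step BatchNorm names give every bucket length its own, differently sized set of auxiliary states. The sketch below only enumerates the names implied by the naming pattern above; the bucket lengths 10 and 20 and the helper bn_state_names are hypothetical, and no MXNet call is involved.

```python
def bn_state_names(seq_len, num_layers=2):
    """List the auxiliary-state names implied by naming BatchNorm per time step.

    Hypothetical helper for illustration; it only builds strings and does not
    touch MXNet.
    """
    names = []
    for seqidx in range(seq_len):
        for layeridx in range(num_layers):
            for branch in ("i2h", "h2h"):
                prefix = "t%d_l%d_%s_bn" % (seqidx, layeridx, branch)
                names.append(prefix + "_moving_mean")
                names.append(prefix + "_moving_var")
    return names


# Two assumed bucket lengths: the longer bucket needs strictly more states.
short_bucket = bn_state_names(10)
long_bucket = bn_state_names(20)
print(len(short_bucket), len(long_bucket))
print(set(short_bucket) < set(long_bucket))
```

Because the number of states grows with the sequence length, an executor built for one bucket cannot simply reuse the state list of another.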

piiswrong commented 8 years ago

@tqchen Can you make aux support variable input?

tqchen commented 8 years ago

I think this needs to come with our next round of refactoring. As a temporary solution, we can bind multiple aux states to the same ndarray.
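The idea of binding several aux states to one array can be sketched in plain Python. This is only an illustration of the aliasing, not MXNet code: Python lists stand in for ndarrays, and the helper alias_aux_states and its size argument are hypothetical.

```python
def alias_aux_states(seq_len, num_layers, size):
    """Map every per-step aux-state name to one shared buffer per layer/branch.

    Plain Python lists stand in for ndarrays; this is an assumption made for
    illustration, not an MXNet API.
    """
    bound = {}
    for layeridx in range(num_layers):
        for branch in ("i2h", "h2h"):
            for stat, init in (("moving_mean", 0.0), ("moving_var", 1.0)):
                shared = [init] * size  # one buffer for all time steps
                for seqidx in range(seq_len):
                    name = "t%d_l%d_%s_bn_%s" % (seqidx, layeridx, branch, stat)
                    bound[name] = shared
    return bound


bound = alias_aux_states(seq_len=3, num_layers=2, size=4)
# Updating the state through one time step's name is visible through the others.
bound["t0_l0_i2h_bn_moving_mean"][0] = 0.5
print(bound["t2_l0_i2h_bn_moving_mean"][0])
```

With this kind of aliasing, a shorter bucket can look up only the names it needs while still sharing the same underlying storage.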

dianyancao commented 8 years ago

I modified some code in the functions load_data_batch and _bind_exec in 'python/mxnet/executor_manager.py' to get it to work. The next question is how to tie the BatchNorm parameters (gamma, beta) and auxiliary states (mean, var) so that the same BatchNorm parameters are used at every time step. The following code does not work: I printed out the BatchNorm parameters and states, and they are not tied. https://github.com/dianyancao/mxnet/blob/master/example/rnn/lstm_bucketing.py

#data_train.default_bucket_key == 82
t0_l0_h2h_bn_beta AverageL2Norm: 0.000124244
t0_l0_h2h_bn_gamma AverageL2Norm: 0.0999984
t0_l0_i2h_bn_beta AverageL2Norm: 0.000124244
t0_l0_i2h_bn_gamma AverageL2Norm: 0.0999349
t0_l1_h2h_bn_beta AverageL2Norm: 0.00244821
t0_l1_h2h_bn_gamma AverageL2Norm: 0.0999984
t0_l1_i2h_bn_beta AverageL2Norm: 0.00244821
t0_l1_i2h_bn_gamma AverageL2Norm: 0.100002
...
t65_l0_h2h_bn_beta AverageL2Norm: 0.0
t65_l0_h2h_bn_gamma AverageL2Norm: 0.0999984
t65_l0_i2h_bn_beta AverageL2Norm: 0.0
t65_l0_i2h_bn_gamma AverageL2Norm: 0.0999984
t65_l1_h2h_bn_beta AverageL2Norm: 0.0
t65_l1_h2h_bn_gamma AverageL2Norm: 0.0999984
t65_l1_i2h_bn_beta AverageL2Norm: 0.0
t65_l1_i2h_bn_gamma AverageL2Norm: 0.0999984
...
t0_l0_h2h_bn_moving_mean AverageL2Norm: 5.99235e-10
t0_l0_h2h_bn_moving_var AverageL2Norm: 0.0147809
t0_l0_i2h_bn_moving_mean AverageL2Norm: 0.00927613
t0_l0_i2h_bn_moving_var AverageL2Norm: 0.0175071
t0_l1_h2h_bn_moving_mean AverageL2Norm: 1.55362e-09
t0_l1_h2h_bn_moving_var AverageL2Norm: 0.0147809
t0_l1_i2h_bn_moving_mean AverageL2Norm: 0.000209122
t0_l1_i2h_bn_moving_var AverageL2Norm: 0.0151166
...
t65_l0_h2h_bn_moving_mean AverageL2Norm: 0.0
t65_l0_h2h_bn_moving_var AverageL2Norm: 1.0
t65_l0_i2h_bn_moving_mean AverageL2Norm: 0.0
t65_l0_i2h_bn_moving_var AverageL2Norm: 1.0
t65_l1_h2h_bn_moving_mean AverageL2Norm: 0.0
t65_l1_h2h_bn_moving_var AverageL2Norm: 1.0
t65_l1_i2h_bn_moving_mean AverageL2Norm: 0.0
t65_l1_i2h_bn_moving_var AverageL2Norm: 1.0

dianyancao commented 8 years ago

I have exposed the BatchNorm auxiliary states (moving_mean, moving_var) as parameters and added 'layer normalization': https://arxiv.org/abs/1607.06450. Now I can use BatchNorm in LSTM bucketing. Thanks.
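For readers unfamiliar with the linked paper, layer normalization computes the mean and variance over the units of a single sample rather than over the batch, so it needs no moving statistics. Below is a minimal pure-Python sketch for one sample; it is not the code from the fork, and the scalar gain and bias are a simplification assumed here (the paper uses per-unit vectors).

```python
import math


def layer_norm(x, gain=1.0, bias=0.0, eps=1e-5):
    """Normalize one sample across its units (a sketch of layer normalization).

    Scalar gain and bias are a simplifying assumption for illustration.
    """
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gain * (v - mean) * inv_std + bias for v in x]


# The statistics come from this one sample only, so nothing depends on the
# batch or on the time step.
print(layer_norm([1.0, 2.0, 3.0, 4.0]))
```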

sxjscience commented 8 years ago

@dianyancao Sounds great! You are more than welcome to PR an example if you have time.

dianyancao commented 8 years ago

@sxjscience, are the operator test cases in '/tests/python/unittest/test_operator.py'? When I run 'test_operator.py' I get some errors: TypeError: unary_ndarray_function() got an unexpected keyword argument 'begin'. My MXNet checkout was freshly cloned today, so is there something that needs to be updated?

I have removed the auxiliary states from BatchNorm and changed them to arguments: { "data", "gamma", "beta", "moving_mean", "moving_var" }. I also added layer normalization, which is applied when (ctx.istrain && param.use_global_stats) is true. This is not tested yet, so I'm not very confident that the code is correct.

sxjscience commented 8 years ago

@dianyancao If you have installed the package using python setup.py install instead of python setup.py develop, you may need to install it again.
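A quick way to check which copy of a package Python would import is to look at where its module spec points. The sketch below uses only the standard library; the helper name package_location is hypothetical, and "mxnet" may not be installed on every machine, in which case None is printed.

```python
import importlib.util


def package_location(name):
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# A path inside site-packages suggests a copied install (setup.py install);
# a path inside the git clone suggests a development install (setup.py develop).
print(package_location("mxnet"))
```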

dianyancao commented 8 years ago

@sxjscience @piiswrong @tqchen, I am new to MXNet. I changed the BatchNorm auxiliary states to parameters, which seems to be a bad idea. I am now testing layer normalization with use_global_stats==true in the training phase using 'test_operator.py'. I don't know how to bind multiple BatchNorm auxiliary states and parameters to the same ndarrays, so I used the bad idea above. Waiting for DMLC members to solve BatchNorm in LSTM...

dianyancao commented 8 years ago

@piiswrong @tqchen @sxjscience Would it be OK to allow the same parameter and auxiliary-state variable names in multiple symbols, so that variables with the same name are bound to the same ndarrays?