cooijmanstim / recurrent-batch-normalization


Padding Affecting Batch Norm #2

Open NickShahML opened 8 years ago

NickShahML commented 8 years ago

Hey Tim,

The batch norm LSTM paper you published has pretty stellar results. I've followed the Keras issue thread where you stated that gamma and beta are shared across all timesteps, while the statistics themselves should be kept separately for each timestep. Thanks for clarifying this.

Unfortunately, I implemented the batch norm LSTM (gamma = 0.1) in TensorFlow, and it does not seem to perform as well as a regular LSTM. I'm applying it to sequences that contain padding.

My question is this: is it possible that the padding is throwing everything off? Towards the end of the sequence, padding probably disrupts the batch mean and variance. Laurent et al. suggest normalizing sequence-wise instead: https://arxiv.org/pdf/1510.01378v1.pdf

Did you use padding in any of your experiments? If so, how did you bridge that gap? I'm thinking that towards the end of the sequence, the mean and variance from timestep 10, where clearly no padding occurs yet, should be kept and reused.

Thanks!

cooijmanstim commented 8 years ago

Yes, this is something we've noticed as well. We're experimenting with two possibilities:

  1. Pad with repetitions of the data rather than zeros
  2. Normalize the input x sequence-wise as in the paper by César Laurent et al. that you link to.

We started out doing variant 1 but I now believe variant 2 is the better choice.
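A minimal NumPy sketch of variant 1, assuming the sequences arrive as a list of (length, features) arrays and that a time-major (time, batch, features) layout is used; `pad_by_repetition` is a hypothetical helper, not code from this repository. The repeated frames only feed the batch statistics; the mask is still what keeps the model from running on them.

```python
import numpy as np

def pad_by_repetition(sequences, max_len):
    """Pad (length, features) arrays into a (max_len, batch, features) batch.

    Each sequence is extended by cycling through its own frames instead of
    appending zeros. The returned (max_len, batch) mask is 1.0 on real frames
    and 0.0 on repeated ones, so the model can still skip the padding while
    batch statistics see data-like values there.
    """
    batch = len(sequences)
    features = sequences[0].shape[1]
    padded = np.empty((max_len, batch, features), dtype=sequences[0].dtype)
    mask = np.zeros((max_len, batch))
    for b, seq in enumerate(sequences):
        length = seq.shape[0]
        idx = np.arange(max_len) % length  # 0, 1, ..., length - 1, 0, 1, ...
        padded[:, b, :] = seq[idx]
        mask[:length, b] = 1.0
    return padded, mask

seqs = [np.array([[1.0], [2.0], [3.0]]), np.array([[10.0], [20.0]])]
padded, mask = pad_by_repetition(seqs, max_len=5)
print(padded[:, 1, 0])  # [10. 20. 10. 20. 10.]
print(mask[:, 1])       # [1. 1. 0. 0. 0.]
```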

Also, the validation curves we report all use batch statistics, and we don't compute the population statistics until the very end of training. We make a final pass over the training set to estimate the population statistics exactly (as opposed to by moving average) and use that to perform our test evaluation.
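The final pass could look roughly like the sketch below, which accumulates exact per-timestep sums instead of a moving average. It is an illustration under stated assumptions, not the authors' code: activations are collected batch by batch with the trained weights frozen, every batch has the same number of timesteps, and padded frames are excluded via a 0/1 mask. A timestep that never contains a real frame would produce NaN.

```python
import numpy as np

def population_statistics(batches):
    """Exact per-timestep mean and variance accumulated over a whole dataset.

    `batches` yields (activations, mask) pairs: activations has shape
    (time, batch, features), computed with the trained weights frozen, and
    mask has shape (time, batch) with 0.0 on padded frames, which are skipped.
    All batches must share the same number of timesteps. Sums are accumulated
    rather than averaged with momentum, so the result matches the statistics
    of the full pass exactly (up to floating point).
    """
    total = sq_total = count = None
    for acts, mask in batches:
        m = mask[:, :, None]
        s = (acts * m).sum(axis=1)          # (time, features)
        sq = (acts ** 2 * m).sum(axis=1)    # (time, features)
        n = mask.sum(axis=1)[:, None]       # (time, 1) real frames per step
        if total is None:
            total, sq_total, count = s, sq, n
        else:
            total, sq_total, count = total + s, sq_total + sq, count + n
    mean = total / count
    # E[x^2] - E[x]^2 is fine for a sketch; a two-pass or Welford update
    # is numerically safer when activations are large.
    var = sq_total / count - mean ** 2
    return mean, var
```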

Thanks for asking, I hope that helps!

NickShahML commented 8 years ago

Thanks Tim for your feedback.

Padding with repetitions of data is an interesting idea. However, I do feel that would throw off the network's performance as it would have to learn to "throw out" the padded material and not consider it.

For option 2, do you mean normalizing just the input and NOT the hidden state into the LSTM?

Because we are using an LSTM, we cannot know the future hidden states until we have computed the entire previous timestep. Normally you want to apply batch norm after multiplying the input by the weights.

For simplicity's sake, I feel that it would be easier to normalize just the input before it is multiplied by the weight matrix if you're going to normalize both by batch and time.

The third option would be to take timestep 10's mean and variance and use them to compute the rest of the timesteps. This way, you're using means and variances with no padding exposure AND you maintain a different set of statistics for the early timesteps. As your paper shows, it's important to keep separate statistics for the early timesteps. Thoughts on this?

Edit: I tried the third option described above and it helped slightly, but it's very clear that the network is not learning as it should.

I recognize that this may be a stupid question, but when you say normalize the input x sequence-wise, do you mean normalizing the input before any weight matrix is multiplied with it? From a programming perspective, normalizing sequence-wise after the weight matrix multiplication is incredibly difficult.

cooijmanstim commented 8 years ago

> Padding with repetitions of data is an interesting idea. However, I do feel that would throw off the network's performance as it would have to learn to "throw out" the padded material and not consider it.

You would still be using a mask to make sure you don't run on the padding, right? The padding would be used only for the purpose of estimating the statistics.

> For option 2, do you mean normalizing just the input and NOT the hidden state into the LSTM?

I mean normalizing only the input term W_x x_t "outside" the LSTM; the W_h h_{t-1} term would still be normalized as in our paper.

> For simplicity's sake, I feel that it would be easier to normalize just the input before it is multiplied by the weight matrix if you're going to normalize both by batch and time.

You can do this and it would help to some extent, but as you say it's better to normalize after the weight matrix. The idea is that the distribution of W_x x_t is determined by the entire weight matrix, and if you normalize it then it's determined only by the gamma and beta vectors.
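A small NumPy illustration of that last point, assuming plain batch norm on a (batch, features) matrix and the gamma = 0.1 initialization mentioned at the top of the thread: after normalization, each feature of W_x x has mean beta and standard deviation gamma, so scaling W by a constant leaves the output essentially unchanged (up to the epsilon term). The function and variable names are illustrative.

```python
import numpy as np

def batch_norm(z, gamma, beta, epsilon=1e-5):
    """Normalize each column of z (batch, features) with this batch's statistics."""
    mean = z.mean(axis=0, keepdims=True)
    var = z.var(axis=0, keepdims=True)
    return beta + gamma * (z - mean) / np.sqrt(var + epsilon)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))   # inputs x_t for one timestep
W = rng.normal(size=(8, 4))     # input-to-hidden weights W_x
gamma, beta = np.full(4, 0.1), np.zeros(4)

out = batch_norm(x @ W, gamma, beta)
out_scaled = batch_norm(x @ (50.0 * W), gamma, beta)  # same W, rescaled

print(out.mean(axis=0), out.std(axis=0))  # means ~0 (beta), stds ~0.1 (gamma)
print(np.abs(out - out_scaled).max())     # tiny: the scale of W no longer matters
```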

> The third option would be to take timestep 10's mean and variance and use them to compute the rest of the timesteps. This way, you're using means and variances with no padding exposure AND you maintain a different set of statistics for the early timesteps. As your paper shows, it's important to keep separate statistics for the early timesteps. Thoughts on this?

That would work as well, but whether it's time step 10 or some other number might depend on your task. If you don't mind tuning it then this is a good solution.
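A sketch of that third option, assuming time-major (time, batch, features) pre-activations and a hypothetical `last_own_step` parameter standing in for "timestep 10" (that would be `last_own_step=9` with 0-based indexing). The statistics remain batch statistics, so in an autodiff framework gradients would still flow through them. The assumption is that no sequence in the batch is padded at or before `last_own_step`.

```python
import numpy as np

def per_timestep_batch_norm(z, gamma, beta, last_own_step, epsilon=1e-5):
    """Batch-normalize z of shape (time, batch, features) with per-timestep statistics.

    Timesteps 0..last_own_step use their own batch mean and variance.
    Every later timestep reuses the statistics of `last_own_step`, which is
    assumed to be early enough that no sequence in the batch is padded yet.
    gamma and beta are shared across all timesteps.
    """
    mean = z.mean(axis=1, keepdims=True)   # (time, 1, features)
    var = z.var(axis=1, keepdims=True)
    mean[last_own_step + 1:] = mean[last_own_step]
    var[last_own_step + 1:] = var[last_own_step]
    return beta + gamma * (z - mean) / np.sqrt(var + epsilon)
```

As Tim notes, the cut-off is task-dependent, so `last_own_step` would be tuned like any other hyperparameter.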

> I recognize that this may be a stupid question, but when you say normalize the input x sequence-wise, do you mean normalizing the input before any weight matrix is multiplied with it? From a programming perspective, normalizing sequence-wise after the weight matrix multiplication is incredibly difficult.

You'll definitely want to normalize after the weight matrix. I'm not sure what difficulties you're thinking of; if you're working with Theano, you'd do something like this:

```python
import theano.tensor as T

# x: inputs with batch and time on axes 0 and 1; mask: 0/1 array over those two axes.
embedding = T.dot(x, W)
mean = (embedding * mask[:, :, None]).sum(axis=[0, 1], keepdims=True) / mask.sum(axis=[0,1], keepdims=True)
variance = ((embedding * mask[:, :, None] - mean)**2).sum(axis=[0,1], keepdims=True) / mask.sum(axis=[0,1], keepdims=True)
embedding = beta + gamma * (embedding - mean) / T.sqrt(variance + epsilon)
```

With axes 0 and 1 being batch and time in any order.

NickShahML commented 8 years ago

Tim, thanks for your extensive reply. I really appreciate your time and feedback.

> You would still be using a mask to make sure you don't run on the padding, right? The padding would be used only for the purpose of estimating the statistics.

Yes, I totally forgot about masking, and that would solve the issue I raised earlier. You said that you started with this idea, but you don't think it's optimal. Why do you feel this way? It seems like it should work.

Your combination of normalizing the hidden state frame-wise and the input both time-wise and frame-wise seems the most logical to me. When you normalize the input sequence-wise, would you include or exclude padded inputs when you compute the mean and variance? From the code you provided, it looks like you would keep the padded frames?

Normalizing the input coming from the embedding layer is very easy, as you described above. However, for layers 2 and 3 the process would need to be repeated, which would take some work to implement. Definitely possible, though.

> That would work as well, but whether it's time step 10 or some other number might depend on your task. If you don't mind tuning it then this is a good solution.

You're right: you would have to estimate the last timestep before padding appears.

I will test this option and report back here with the results in case others come across this thread. I will apply this continued average only to BN(Wx), not to the hidden state.

Will post back later with results.

EDIT: I have found that if I batch normalize only the hidden-state input, it improves the network! Applying BN to the tanh(new_c) term seems to hurt, and applying BN to the W_x input term also seems to hurt.

I will try normalizing the input time-wise, see if this makes a difference, and report back.

cooijmanstim commented 8 years ago

> You said that you started with this idea, but you don't think it's optimal. Why do you feel this way? It seems like it should work.

The problem I see is that the timestep-wise normalization destroys the dynamics of the input data. E.g. if your input is a one-dimensional signal such as an audio waveform, the normalization will amplify the quiet parts and attenuate the loud parts. The model won't know which is which anymore, and will easily confuse noise for signal.

More generally, timestep-wise estimation works for stationary input signals. If the distribution is not stationary (e.g. there are loud parts and quiet parts), a mean/variance estimate based on such a narrow temporal window is a bad estimate of the global mean/variance.
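The amplitude argument can be checked with a toy example. This is an assumed setup, not one of the experiments discussed here: a synthetic one-dimensional signal whose first half has amplitude 0.01 and second half amplitude 1.0 (arbitrary choices), normalized once with per-timestep statistics and once with a single sequence-wise mean and variance.

```python
import numpy as np

rng = np.random.default_rng(0)
time, batch = 100, 64
# Toy 1-D "waveform": the first half is quiet, the second half is loud.
amplitude = np.where(np.arange(time) < time // 2, 0.01, 1.0)
x = amplitude[:, None] * rng.normal(size=(time, batch))  # (time, batch)
eps = 1e-8

# Timestep-wise: a separate mean/variance per timestep, taken over the batch.
step_wise = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + eps)
# Sequence-wise: a single mean/variance taken over batch and time together.
seq_wise = (x - x.mean()) / np.sqrt(x.var() + eps)

def loud_to_quiet(y):
    """Ratio of the standard deviation in the loud half to that in the quiet half."""
    return y[time // 2:].std() / y[:time // 2].std()

print(loud_to_quiet(step_wise))  # close to 1: loudness information is gone
print(loud_to_quiet(seq_wise))   # roughly 100: the contrast is preserved
```

With per-timestep statistics both halves end up with unit variance, so the model can no longer tell them apart; the sequence-wise version only shifts and rescales the whole signal.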

> When you normalize the input sequence-wise, would you include or exclude padded inputs when you compute the mean and variance? From the code you provided, it looks like you would keep the padded frames?

I noticed a bug in the code I provided; the variance computation should be:

```python
variance = ((embedding - mean)**2 * mask[:, :, None]).sum(axis=[0,1], keepdims=True) / mask.sum(axis=[0,1], keepdims=True)
```

That is, the multiplication by the mask should be moved outside the squared difference. I multiply by the mask and divide by the number of ones in the mask, so the padded elements do not contribute to the estimate.
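For readers not on Theano, here is a NumPy transcription of the corrected computation. It is a sketch assuming time-major (time, batch, features) arrays and a 0/1 float mask of shape (time, batch); `masked_sequence_batch_norm` is an illustrative name. Outputs at padded positions are computed but meaningless, so they should still be masked downstream.

```python
import numpy as np

def masked_sequence_batch_norm(embedding, mask, gamma, beta, epsilon=1e-5):
    """Sequence-wise batch norm over the batch and time axes, ignoring padding.

    embedding: (time, batch, features) activations, e.g. x @ W
    mask:      (time, batch), 1.0 for real frames and 0.0 for padding
    """
    m = mask[:, :, None]
    count = mask.sum()
    mean = (embedding * m).sum(axis=(0, 1), keepdims=True) / count
    # The mask is applied after squaring, so padded frames add nothing here.
    variance = ((embedding - mean) ** 2 * m).sum(axis=(0, 1), keepdims=True) / count
    return beta + gamma * (embedding - mean) / np.sqrt(variance + epsilon)
```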

I'm very curious to hear more about your findings. What kind of data are you working with?

NickShahML commented 8 years ago

Happy to help. I run 5 separate Titan X/980 Ti cards, so I try to test as much as possible, as quickly as possible. My models usually have two or three layers of 512 units. I am primarily working with English text data tokenized into words (usually 120 timesteps).

> More generally, timestep-wise estimation works for stationary input signals. If the distribution is not stationary (e.g. there are loud parts and quiet parts), a mean/variance estimate based on such a narrow temporal window is a bad estimate of the global mean/variance.

If we are indeed going to normalize sequence-wise and batch-wise, then the words that are "really loud" will be somewhat dampened. This would be skewed even more with a small batch size.

Do you think one potential solution is to always apply a running mean and variance at each timestep? Instead of normalizing by the specific batch's statistics, it might be better to normalize by a running mean and variance. Of course, you could keep a separate running mean and variance for each timestep, and once you pass timestep 10, apply the same running mean and variance throughout.

Findings: I have found that normalizing the input on layers 2+ hurts the net, but not as drastically as normalizing the input on layer 1.

I also noticed that you can raise the learning rate even if you only normalize the hidden-state input. I'm using Adam with a learning rate of 0.005 (which is pretty high) plus a learning rate schedule.

I have found that applying batch norm to attention hurts performance, which makes me think there's something crucially wrong with my implementation. Will comment later when I have more findings.

cooijmanstim commented 8 years ago

> If we are indeed going to normalize sequence-wise and batch-wise, then the words that are "really loud" will be somewhat dampened. This would be skewed even more with a small batch size.

I think if your words are one-hot encoded you should be fine (though if your vocabulary is large you may need a larger batch size or do this sequence-wise normalization which effectively increases your sample size). In audio data on the other hand, the input is typically real-valued and its absolute value varies a lot and this variability is highly informative so you don't want to lose it. I suspect this is a part of why batch normalization doesn't seem to help on speech (recurrently or otherwise).

> Do you think one potential solution is to simply always apply a running mean and variance to each timestep?

The problem with that is that you can't backprop through the mean and variance, which is crucial. Batch normalization just doesn't seem to work if you don't do this. The accepted explanation for this is that the gradient should take into account the effect of the parameter update on the statistics, or optimization may go around in circles. There may be more to it.
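The gradient difference can be seen on ordinary (non-recurrent) batch norm. The NumPy sketch below is illustrative only, with beta omitted for brevity: `bn_grad_full` differentiates through the batch mean and variance using the standard batch-norm backward formula, while `bn_grad_frozen` treats them as constants, which is effectively what happens when the normalization uses running averages that sit outside the graph. A central finite-difference check shows which one is the true gradient of the loss. It demonstrates only the gradient mismatch, not the optimization behavior described above.

```python
import numpy as np

def bn_forward(x, gamma, epsilon=1e-5):
    """Batch norm over the batch axis of x (batch, features); beta omitted."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return gamma * x_hat, x_hat, var

def bn_grad_full(dy, x_hat, var, gamma, epsilon=1e-5):
    """dL/dx when backpropagating through the batch mean and variance."""
    n = dy.shape[0]
    inv_std = 1.0 / np.sqrt(var + epsilon)
    dx_hat = dy * gamma
    return inv_std / n * (n * dx_hat - dx_hat.sum(axis=0)
                          - x_hat * (dx_hat * x_hat).sum(axis=0))

def bn_grad_frozen(dy, var, gamma, epsilon=1e-5):
    """dL/dx if the mean and variance are treated as constants."""
    return dy * gamma / np.sqrt(var + epsilon)

# Finite-difference check with the loss L(x) = sum(w * BN(x)).
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
w = rng.normal(size=(8, 3))
gamma = 0.1

def loss(x_in):
    return (w * bn_forward(x_in, gamma)[0]).sum()

numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    step = np.zeros_like(x)
    step[idx] = h
    numeric[idx] = (loss(x + step) - loss(x - step)) / (2 * h)

_, x_hat, var = bn_forward(x, gamma)
full = bn_grad_full(w, x_hat, var, gamma)
frozen = bn_grad_frozen(w, var, gamma)
print(np.abs(full - numeric).max())    # tiny: matches the true gradient
print(np.abs(frozen - numeric).max())  # much larger: ignores the statistics
```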

I will be busy preparing our NIPS submission and moving, so I may be less responsive in the next week, but I do appreciate the discussion! :-)

NickShahML commented 8 years ago

> I will be busy preparing our NIPS submission and moving, so I may be less responsive in the next week, but I do appreciate the discussion! :-)

No problem, just reply whenever convenient. I appreciate your thoughts.

> I think if your words are one-hot encoded you should be fine (though if your vocabulary is large you may need a larger batch size or do this sequence-wise normalization which effectively increases your sample size).

I usually train an embedding layer jointly with the model, as it performs better than using word2vec or GloVe. My vocabulary size is usually 40k.

> The problem with that is that you can't backprop through the mean and variance, which is crucial. Batch normalization just doesn't seem to work if you don't do this. The accepted explanation for this is that the gradient should take into account the effect of the parameter update on the statistics, or optimization may go around in circles. There may be more to it.

I didn't know this, so thank you. It would explain many findings I have had lately: I tried applying the running mean and variance and things did not improve at all.

I think I'm going to try applying batch normalization to the associative LSTM (http://arxiv.org/abs/1602.03032) and see what happens.

Findings:

I tried sequence-wise AND batch-wise normalization of the input and it helped somewhat, but not significantly. I made sure to exclude padded frames. Really, the main benefit I've seen is from batch-normalizing the hidden state. I tried stacking 4 or 5 layers with batch-normalized inputs and unfortunately it did not help.

zhengwy888 commented 8 years ago

@LeavesBreathe Hello, I am implementing a simple batch-normalized LSTM in TensorFlow as well. Could you explain: when you batch normalized the hidden state, were you normalizing just 'h', or 'Wh'? Are they normalized timestep-wise, meaning there is one mean and variance for each layer/timestep combination? Do you share gamma and beta across layers? The other problem I encountered is that batch normalization slows down training a lot; are you seeing the same issue?

NickShahML commented 8 years ago

BN doesn't really slow down training for me, maybe a 10% slowdown in step time.

Sorry for the confusion, but I normally do BN(Wh). I normalize it at each timestep for each layer separately, and I share one gamma per layer.
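A NumPy sketch of one layer set up as described: BN applied only to W_h h_{t-1}, statistics recomputed at every timestep, and a single gamma shared by all timesteps of the layer. Assumptions not stated in the thread: gate order i, f, o, g; gamma_h is a vector over the 4 * units gate pre-activations (for example filled with 0.1, as mentioned at the top of the issue); W_x x_t and the cell are left unnormalized; and there is no masking of padded frames. All names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def batch_norm(z, gamma, epsilon=1e-5):
    """Normalize (batch, units) with the statistics of this batch at this timestep."""
    return gamma * (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + epsilon)

def lstm_layer_bn_hidden(x, W_x, W_h, b, gamma_h):
    """Run one LSTM layer over x of shape (time, batch, inputs).

    Only the recurrent term W_h h_{t-1} is batch-normalized. Statistics are
    recomputed at every timestep (one mean/variance per timestep), while the
    single gamma_h is shared by all timesteps of this layer.
    """
    time, batch, _ = x.shape
    units = W_h.shape[0]
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    outputs = []
    for t in range(time):
        gates = x[t] @ W_x + batch_norm(h @ W_h, gamma_h) + b  # (batch, 4 * units)
        i, f, o, g = np.split(gates, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs)
```

Called with x of shape (time, batch, inputs), W_x of shape (inputs, 4 * units) and W_h of shape (units, 4 * units), it returns hidden states of shape (time, batch, units).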

Let us know if you get any improvements with BN(Wx)!

Pinlong-Zhao commented 8 years ago

@cooijmanstim @LeavesBreathe Hello, I don't know how to get the datasets used in these experiments. Could you tell me where I can find them, or send them to my e-mail, 394523651@qq.com? Thank you very much.

OverLordGoldDragon commented 4 years ago

Nice discussion, thanks for the insights


@cooijmanstim

> The problem I see is that the timestep-wise normalization destroys the dynamics of the input data

"Timestep-wise" = w.r.t. (i.e. collapsing) the samples and channels dimensions? If so, fair enough, but if channels are left untouched, I wouldn't consider standardizing w.r.t. samples unfair game, since BN works with distribution estimates anyway. A single pass over the dataset with frozen weights is an excellent idea, by the way; it should be applicable to regular batch norm as well.

Also, correct me if I'm wrong, but the answer to the issue I opened is here:

> The problem with that is that you can't backprop through the mean and variance, which is crucial

and it does make sense: batch statistics are computed per batch.