casperkaae / parmesan

Variational and semi-supervised neural network toppings for Lasagne

BatchNormalization #15

Closed botev closed 8 years ago

botev commented 8 years ago

It would be nice if the BatchNormalizationLayer, rather than supporting a "single_pass", supported a "collect" mode, which would accumulate the dataset statistics over minibatches. If the dataset is big enough, the "single_pass" could fail (the whole dataset may not fit through the network in one batch). An alternative would be something like this:

if collect:
    # Collect the dataset statistics over minibatches.
    # To do this accurately we need an extra accumulator for E[x^2]
    # and derive the std from it, instead of averaging stds directly.
    # (Assumes `import numpy as np`, `import theano` and
    # `import theano.tensor as T` at module level.)
    running_ex2 = theano.shared(np.zeros_like(self.std.get_value()))
    # Float counter so the divisions below are not integer divisions
    t = theano.shared(np.asarray(1.0, dtype=theano.config.floatX))
    ex2 = T.mean(T.sqr(input), axis=self.axes, keepdims=True)
    # Incremental arithmetic averages over the t minibatches seen so far
    mean_update = ((t - 1) / t) * running_mean + mean / t
    ex2_update = ((t - 1) / t) * running_ex2 + ex2 / t
    std_update = T.sqrt(ex2_update - T.sqr(mean_update) + self.epsilon)
    # Set the default updates
    running_mean.default_update = mean_update
    running_ex2.default_update = ex2_update
    running_std.default_update = std_update
    t.default_update = t + 1
    # and include them in the graph so their default updates will be
    # applied (although the expressions will be optimized away later)
    mean += 0 * running_mean + 0 * t
    std += 0 * running_std + 0 * running_ex2
else:
    # During training we instead use a geometric (exponential) moving average
    running_mean.default_update = ((1 - self.alpha) * running_mean +
                                   self.alpha * mean)
    running_std.default_update = ((1 - self.alpha) * running_std +
                                  self.alpha * std)
    # and include them in the graph so their default updates will be
    # applied (although the expressions will be optimized away later)
    mean += 0 * running_mean
    std += 0 * running_std
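
A rough usage sketch of what this would enable after training. All names here (collect_statistics, l_out, input_var, iterate_minibatches, X_train) are hypothetical placeholders, not existing Parmesan/Lasagne API:

    import theano
    import lasagne

    # Hypothetical flag: pass the proposed 'collect' mode down to the layer's
    # get_output_for via get_output's keyword arguments.
    collect_out = lasagne.layers.get_output(l_out, collect_statistics=True)
    collect_fn = theano.function([input_var], collect_out)

    # One pass over the training data; every call triggers the shared
    # variables' default_update expressions above, so running_mean and
    # running_std end up holding the exact dataset statistics.
    for x_batch in iterate_minibatches(X_train, batch_size=100):
        collect_fn(x_batch)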
skaae commented 8 years ago

It would be nice if the BatchNormalizationLayer, rather than supporting a "single_pass", supported a "collect" mode, which would accumulate the dataset statistics over minibatches.

I think it does, if you set alpha to e.g. 0.5? This is untested, but should work.

botev commented 8 years ago

That would take a geometric average over the minibatches. E.g. with alpha = 0.5 and 3 minibatches you get

mean = 0.125*m_1 + 0.25*m_2 + 0.5*m_3

whereas what you want in the end is

mean = 0.333*m_1 + 0.333*m_2 + 0.333*m_3
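
A quick numeric check of the difference (plain Python, with three made-up minibatch means):

    # Geometric (exponential) moving average with alpha = 0.5
    # versus the plain arithmetic mean of three minibatch means.
    m = [1.0, 2.0, 3.0]        # m_1, m_2, m_3
    running = 0.0
    for m_i in m:
        running = 0.5 * running + 0.5 * m_i
    print(running)             # 0.125*1 + 0.25*2 + 0.5*3 = 2.125
    print(sum(m) / len(m))     # 2.0 -- the statistic we actually want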

Or am I misunderstanding what the point of the "single_pass" is?

skaae commented 8 years ago

Single pass assumes that you pass the entire dataset through the network in a single batch. That way you collect the correct statistics for all batch normalization layers.
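
In other words, a minimal sketch of the idea (l_out, input_var and X_train are placeholder names, and how the layer stores the collected statistics depends on the alpha='single_pass' implementation):

    # 'single_pass': feed the whole training set as one batch, so the batch
    # statistics each BatchNormalizationLayer computes are exactly the
    # dataset statistics, which the layer can then keep for test time.
    out = lasagne.layers.get_output(l_out, deterministic=False)
    single_pass_fn = theano.function([input_var], out)
    single_pass_fn(X_train)    # one call with the entire dataset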

It's probably correct that alpha != 'single_pass' uses a geometric average. It's copied from @f0k's implementation.

botev commented 8 years ago

Aha, I get it. Well, then what I suggest is just a convenience for collecting the statistics after training is done! Thanks for the clarification.