nengo / nengo-dl

Deep learning integration for Nengo
https://www.nengo.ai/nengo-dl

BatchNormalization layer produces LOW accuracy #234

Open kebitmatf opened 1 year ago

kebitmatf commented 1 year ago

Dear all,

I'm testing the BatchNormalization layer on the MNIST dataset from the tutorial example. Adding BatchNormalization layers drops the validation accuracy from 98% to 72%. Below is the sample code, the same as the tutorial example except that I added a BatchNormalization layer after each spiking layer:

```python
import nengo
import nengo_dl
import numpy as np
import tensorflow as tf

with nengo.Network(seed=0) as net:
    # set some default parameters for the neurons that will make
    # the training progress more smoothly
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None
    neuron_type = nengo.LIF(amplitude=0.01)

    # this is an optimization to improve the training speed,
    # since we won't require stateful behaviour in this example
    nengo_dl.configure_settings(stateful=False)

    # the input node that will be used to feed in input images
    inp = nengo.Node(np.zeros(28 * 28))

    # add the first convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        inp, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)
    x = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(x)

    x1_p = nengo.Probe(x)
    x1_p_fil = nengo.Probe(x, synapse=0.1)

    # add the second convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)
    x = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(x)

    # add the third convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)
    x = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(x)

    # linear readout
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    # we'll create two different output probes, one with a filter
    # (for when we're simulating the network over time and
    # accumulating spikes), and one without (for when we're
    # training the network using a rate-based approximation)
    out_p = nengo.Probe(out, label="out_p")
    out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")
```

What's wrong here? Why does the BatchNormalization layer drop the accuracy so much? Thank you!
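For reference, my understanding is that BatchNormalization normalizes with the current batch statistics during training but with stored moving averages at inference, so a mismatch between the two could cause an accuracy gap like this. Here is a small plain-Python sketch of the two modes that I used to check the difference (`bn_train` and `bn_inference` are just illustrative helper names, not part of the Nengo or Keras API):

```python
def bn_train(batch, moving_mean, moving_var, momentum=0.99, eps=1e-3):
    """Normalize with the current batch statistics and update the moving
    averages (roughly what a BatchNormalization layer does when training)."""
    n = len(batch)
    mean = sum(batch) / n
    var = sum((x - mean) ** 2 for x in batch) / n
    out = [(x - mean) / (var + eps) ** 0.5 for x in batch]
    moving_mean = momentum * moving_mean + (1 - momentum) * mean
    moving_var = momentum * moving_var + (1 - momentum) * var
    return out, moving_mean, moving_var


def bn_inference(batch, moving_mean, moving_var, eps=1e-3):
    """Normalize with the stored moving averages (inference mode)."""
    return [(x - moving_mean) / (moving_var + eps) ** 0.5 for x in batch]


# After only a few updates the moving averages still sit near their
# initial values (mean 0, variance 1), so inference-mode output can be
# very different from training-mode output for the same batch.
batch = [10.0, 12.0, 14.0]
train_out, m, v = bn_train(batch, moving_mean=0.0, moving_var=1.0)
infer_out = bn_inference(batch, m, v)
```

In this sketch `train_out` is centered around zero, while `infer_out` is far from it because the moving averages have barely moved, which is the kind of train/inference discrepancy I suspect might be involved here.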