nengo / nengo-dl

Deep learning integration for Nengo
https://www.nengo.ai/nengo-dl
Other
88 stars 22 forks source link

load_params misbehaves with scale_firing_rates for some architectures #213

Open arvoelke opened 3 years ago

arvoelke commented 3 years ago

Dependencies:

This is a stripped-down version of https://www.nengo.ai/nengo-dl/examples/keras-to-snn.html with a do_bug flag added. This code trains the network once and then evaluates the test accuracy using non-spiking activations with two different values of scale_firing_rates.

The do_bug flag selects the architecture: dense when True and convolutional when False, where the latter is the one in the docs example. The issue appears to have something to do with load_params not being consistent in how it applies the saved params onto the model post-conversion and post-build, when a different scale_firing_rates is used between training and testing.

Since it works with a convolutional architecture (i.e., test results are consistent regardless of scale_firing_rates), and it is advertised to work in the documentation, I'd expect it to work with other architectures as well.

# Reproducer for load_params + scale_firing_rates: the network is trained once
# with a default nengo_dl Converter, then re-converted and evaluated with two
# different scale_firing_rates values after loading the saved params.
# do_bug=True builds an all-Dense model (where the accuracies disagree);
# do_bug=False builds the convolutional model from the docs example.
do_bug = False  # <-- change this and re-run

import nengo
import numpy as np
import tensorflow as tf

import nengo_dl

# seed NumPy and TensorFlow so repeated runs are comparable; this must happen
# before the Keras layers are created below
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)

(train_images, train_labels), (
    test_images,
    test_labels,
) = tf.keras.datasets.mnist.load_data()

# flatten images and add time dimension
# (each array becomes (n_examples, 1, -1): axis 1 is a length-1 time axis)
train_images = train_images.reshape((train_images.shape[0], 1, -1))
train_labels = train_labels.reshape((train_labels.shape[0], 1, -1))
test_images = test_images.reshape((test_images.shape[0], 1, -1))
test_labels = test_labels.reshape((test_labels.shape[0], 1, -1))

if do_bug:
    # all-Dense architecture: three ReLU Dense layers followed by a linear
    # 10-unit output; this is the variant that exhibits the inconsistency
    inp = tf.keras.layers.Input(shape=(28, 28, 1))
    q_flat = tf.keras.layers.Flatten()(inp)
    q_dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)(q_flat)
    q_dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)(q_dense1)
    q_dense3 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(q_dense2)
    dense = tf.keras.layers.Dense(10)(q_dense3)

else:  # same as notebook example
    # input
    inp = tf.keras.Input(shape=(28, 28, 1))

    # convolutional layers
    conv0 = tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=3,
        activation=tf.nn.relu,
    )(inp)

    conv1 = tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=2,
        activation=tf.nn.relu,
    )(conv0)

    # fully connected layer
    flatten = tf.keras.layers.Flatten()(conv1)
    dense = tf.keras.layers.Dense(units=10)(flatten)

# `inp` and `dense` are used later as keys into the converters' inputs/outputs
model = tf.keras.Model(inputs=inp, outputs=dense)
model.summary()

# Train through nengo_dl: convert the Keras model with default converter
# settings, fit it in a Simulator, and save the resulting params to disk so
# that run_network can load them into differently-scaled conversions.
converter = nengo_dl.Converter(model)
train_input = converter.inputs[inp]
train_output = converter.outputs[dense]

with nengo_dl.Simulator(converter.net, minibatch_size=200) as sim:
    sim.compile(
        optimizer=tf.optimizers.Adam(0.001),
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.metrics.sparse_categorical_accuracy],
    )
    held_out = ({train_input: test_images}, {train_output: test_labels})
    sim.fit(
        {train_input: train_images},
        {train_output: train_labels},
        validation_data=held_out,
        epochs=2,
    )

    # write the trained parameters for later use with load_params
    sim.save_params("./keras_to_snn_params")

def run_network(
    activation,
    params_file="keras_to_snn_params",
    n_steps=30,
    scale_firing_rates=1,
    synapse=None,
    n_test=400,
):
    """Re-convert ``model``, load saved params into it, and print test accuracy.

    ``tf.nn.relu`` is swapped for ``activation`` during conversion, the
    parameters in ``params_file`` are loaded via ``Simulator.load_params``,
    and the first ``n_test`` test images are classified. Per the issue report,
    this ``load_params`` step is where the accuracy becomes dependent on
    ``scale_firing_rates`` for some architectures.

    Parameters
    ----------
    activation : nengo neuron type
        Replacement for every ``tf.nn.relu`` activation in ``model``.
    params_file : str
        Path handed to ``Simulator.load_params``.
    n_steps : int
        Number of timesteps each test image is repeated for.
    scale_firing_rates : float
        Forwarded to ``nengo_dl.Converter``.
    synapse : optional
        Forwarded to ``nengo_dl.Converter``.
    n_test : int
        Number of test images evaluated, taken from the start of the set.
        NOTE(review): the simulator uses ``minibatch_size=10``, so ``n_test``
        presumably must be a multiple of 10 -- confirm.
    """
    snn_converter = nengo_dl.Converter(
        model,
        swap_activations={tf.nn.relu: activation},
        scale_firing_rates=scale_firing_rates,
        synapse=synapse,
    )
    snn_input = snn_converter.inputs[inp]
    snn_output = snn_converter.outputs[dense]

    # present each (n, 1, features) image for n_steps timesteps along axis 1
    repeated = np.tile(test_images[:n_test], (1, n_steps, 1))

    with nengo_dl.Simulator(
        snn_converter.net, minibatch_size=10, progress_bar=False
    ) as snn_sim:
        snn_sim.load_params(params_file)
        outputs = snn_sim.predict({snn_input: repeated})

    # classify from the network output on the final timestep only
    final_step = outputs[snn_output][:, -1]
    guesses = np.argmax(final_step, axis=-1)
    accuracy = (guesses == test_labels[:n_test, 0, 0]).mean()
    print(f"Test accuracy: {100 * accuracy:.2f}%")

# Evaluate the same trained params at two very different firing-rate scales;
# with non-spiking RectifiedLinear neurons the two accuracies are expected to
# match, and a mismatch demonstrates the issue.
for rate_scale in (100, 0.01):
    run_network(activation=nengo.RectifiedLinear(), scale_firing_rates=rate_scale)

Issue discovered by @studywolf.

arvoelke commented 3 years ago

As a workaround, you can do the training in Keras, save the parameters via model.save_weights, and then load them via model.load_weights before calling the converter. Here's the example from above modified accordingly. With this change, the test result stays the same as scale_firing_rates is varied, for either setting of do_bug (i.e., for both architectures).

# Workaround version: train in Keras, save/restore with Keras weights, and
# only then build the nengo_dl Converter (no Simulator.load_params involved).
do_bug = False  # <-- no longer a bug here

import nengo
import numpy as np
import tensorflow as tf

import nengo_dl

# seed NumPy and TensorFlow so repeated runs are comparable; this must happen
# before the Keras layers are created below
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)

# images/labels are left in their loaded shape here; the test split is
# reshaped for nengo_dl only after Keras training
(train_images, train_labels), (
    test_images,
    test_labels,
) = tf.keras.datasets.mnist.load_data()

if do_bug:
    # all-Dense architecture: three ReLU Dense layers followed by a linear
    # 10-unit output (the variant that misbehaved with load_params)
    inp = tf.keras.layers.Input(shape=(28, 28, 1))
    q_flat = tf.keras.layers.Flatten()(inp)
    q_dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)(q_flat)
    q_dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)(q_dense1)
    q_dense3 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(q_dense2)
    dense = tf.keras.layers.Dense(10)(q_dense3)

else:  # same as notebook example
    # input
    inp = tf.keras.Input(shape=(28, 28, 1))

    # convolutional layers
    conv0 = tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=3,
        activation=tf.nn.relu,
    )(inp)

    conv1 = tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=2,
        activation=tf.nn.relu,
    )(conv0)

    # fully connected layer
    flatten = tf.keras.layers.Flatten()(conv1)
    dense = tf.keras.layers.Dense(units=10)(flatten)

# `inp` and `dense` are used later as keys into the converter's inputs/outputs
model = tf.keras.Model(inputs=inp, outputs=dense)
model.summary()

# Train directly in Keras (no nengo_dl involved yet) and checkpoint the
# weights with Keras' own save_weights.
# NOTE(review): the raw images have no trailing channel axis, unlike
# Input(shape=(28, 28, 1)); this relies on Keras adapting the input rank --
# confirm for the TF version in use.
model.compile(
    optimizer=tf.optimizers.Adam(0.001),
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.metrics.sparse_categorical_accuracy],
)
keras_validation = (test_images, test_labels)
model.fit(
    x=train_images,
    y=train_labels,
    validation_data=keras_validation,
    epochs=2,
)

# save the parameters to file
checkpoint_path = "./keras_to_snn_params"
model.save_weights(checkpoint_path)

# flatten images and add time dimension: each test array becomes
# (n_examples, 1, -1), the layout consumed by the nengo_dl evaluation
test_images = test_images.reshape(len(test_images), 1, -1)
test_labels = test_labels.reshape(len(test_labels), 1, -1)

def run_network(
    activation,
    params_file="keras_to_snn_params",
    n_steps=30,
    scale_firing_rates=1,
    synapse=None,
    n_test=400,
):
    """Restore Keras weights, convert ``model`` to nengo, and print test accuracy.

    The weights in ``params_file`` are loaded onto the module-level Keras
    ``model`` *before* conversion (this mutates ``model``), so the converter
    applies ``scale_firing_rates`` to already-trained weights and no
    ``Simulator.load_params`` call is needed afterwards.

    Parameters
    ----------
    activation : nengo neuron type
        Replacement for every ``tf.nn.relu`` activation in ``model``.
    params_file : str
        Path handed to ``model.load_weights``.
    n_steps : int
        Number of timesteps each test image is repeated for.
    scale_firing_rates : float
        Forwarded to ``nengo_dl.Converter``.
    synapse : optional
        Forwarded to ``nengo_dl.Converter``.
    n_test : int
        Number of test images evaluated, taken from the start of the set.
        NOTE(review): the simulator uses ``minibatch_size=10``, so ``n_test``
        presumably must be a multiple of 10 -- confirm.
    """
    # order matters: weights must be on the Keras model before converting it
    model.load_weights(params_file)

    snn_converter = nengo_dl.Converter(
        model,
        swap_activations={tf.nn.relu: activation},
        scale_firing_rates=scale_firing_rates,
        synapse=synapse,
    )
    snn_input = snn_converter.inputs[inp]
    snn_output = snn_converter.outputs[dense]

    # present each (n, 1, features) image for n_steps timesteps along axis 1
    repeated = np.tile(test_images[:n_test], (1, n_steps, 1))

    with nengo_dl.Simulator(
        snn_converter.net, minibatch_size=10, progress_bar=False
    ) as snn_sim:
        outputs = snn_sim.predict({snn_input: repeated})

    # classify from the network output on the final timestep only
    final_step = outputs[snn_output][:, -1]
    guesses = np.argmax(final_step, axis=-1)
    accuracy = (guesses == test_labels[:n_test, 0, 0]).mean()
    print(f"Test accuracy: {100 * accuracy:.2f}%")

# Evaluate the restored weights at two very different firing-rate scales;
# with this workaround the two printed accuracies are expected to match.
for rate_scale in (100, 0.01):
    run_network(activation=nengo.RectifiedLinear(), scale_firing_rates=rate_scale)