genn-team / ml_genn

A library for deep learning with Spiking Neural Networks (SNN).
https://ml-genn.readthedocs.io
GNU Lesser General Public License v2.1

Accuracy output bug for validation_split in the train method (EventProp) #124

Open Saraworld93 opened 4 hours ago

Saraworld93 commented 4 hours ago
[Screenshot of example run output, 2024-11-15 15:00]

As shown in the screenshot of the example run, training accuracy is 0.9925 (99.25%) and test accuracy is 0.8397 (83.97%), but the validation accuracy seems to be off by a factor of 100, or perhaps isn't correct at all: it is reported as 0.00816 (0.816%). Please let me know if you need any further information or code. Thank you very much in advance.
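
Edit: for reference, one way I could sanity-check this would be to hold out the validation slice myself and score it with a separately compiled inference network. This is only a sketch, assuming the same train_spikes/train_labels lists, constants and populations as the snippets below, and assuming the held-out slice is simply the last 10% of the data (which may not match how validation_split selects it internally):

# Hypothetical manual cross-check (not ml_genn's internal validation path):
# hold out the last 10% of examples and evaluate them directly
split = int(0.9 * len(train_spikes))
val_spikes, val_labels = train_spikes[split:], train_labels[split:]

inference_compiler = InferenceCompiler(evaluate_timesteps=max_example_timesteps,
                                       batch_size=BATCH_SIZE)
inference_net = inference_compiler.compile(network)
with inference_net:
    val_metrics, _ = inference_net.evaluate({input: val_spikes},
                                            {output: val_labels})
    print(f"Manual validation accuracy = {100 * val_metrics[output].result}%")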

neworderofjamie commented 3 hours ago

What model are you running?

Saraworld93 commented 3 hours ago

I'm running the EventProp example; here's a snippet:

# Network definition
serialiser = Numpy("nmnist_checkpoints")
network = SequentialNetwork(default_params)
with network:
    # Populations
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * max_spikes), NUM_INPUT)
    hidden = Layer(Dense(Normal(mean=0.078, sd=0.045)),
                         LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                            tau_refrac=None),
                         NUM_HIDDEN, Exponential(5.0))
    output = Layer(Dense(Normal(mean=0.2, sd=0.37)),
                         LeakyIntegrate(tau_mem=20.0, readout="avg_var"),
                         NUM_OUTPUT, Exponential(5.0))

# Compilation
max_example_timesteps = int(np.ceil(calc_latest_spike_time(train_spikes)))
compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
                             losses="sparse_categorical_crossentropy",
                             optimiser=Adam(1e-2), reg_lambda_upper=1e-9,
                             reg_lambda_lower=1e-9, batch_size=BATCH_SIZE)

compiled_net = compiler.compile(network)

# Metrics storage
train_accuracies = []
validation_accuracies = []

# Initialise early stopping
early_stopping = EarlyStopping(patience=20)

# Training
with compiled_net:
    start_time = perf_counter()
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    validation_callbacks = ["batch_progress_bar", Checkpoint(serialiser)]

    for epoch in range(11):
        train_metrics, validation_metrics, _, _ = compiled_net.train(
            {input: train_spikes}, {output: train_labels},
            num_epochs=1, shuffle=True,
            callbacks=callbacks,
            validation_callbacks=validation_callbacks,
            validation_split=0.1)

neworderofjamie commented 3 hours ago

Is there not a print line at the end somewhere like:

print(f"Accuracy = {100 * metrics[output].result}%")
Saraworld93 commented 2 hours ago

Yes, there is one for train, but even if I print the validation accuracy * 100 it is still off by a factor of 100. I have just added another print statement for validation to verify; screenshot below:

[Screenshot of updated run output with validation print, 2024-11-15 16:48]

The last accuracy is the test set one.
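
In other words, the raw validation result looks as if it has already been divided by 100 once. A quick arithmetic check of the values from the first screenshot:

# Quick check of the reported numbers (hypothesis: one extra division by 100)
val_reported = 0.00816           # raw validation result from train()
print(val_reported * 100)        # -> 0.816, which my print shows as "0.816%"
print(val_reported * 100 * 100)  # -> 81.6, the sort of percentage I'd expect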

Extended code snippet:

# Network definition
serialiser = Numpy("nmnist_checkpoints")
network = SequentialNetwork(default_params)
with network:
    # Populations
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * max_spikes), NUM_INPUT)
    hidden = Layer(Dense(Normal(mean=0.078, sd=0.045)),
                         LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                            tau_refrac=None),
                         NUM_HIDDEN, Exponential(5.0))
    output = Layer(Dense(Normal(mean=0.2, sd=0.37)),
                         LeakyIntegrate(tau_mem=20.0, readout="avg_var"),
                         NUM_OUTPUT, Exponential(5.0))

# Compilation
max_example_timesteps = int(np.ceil(calc_latest_spike_time(train_spikes)))
compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
                             losses="sparse_categorical_crossentropy",
                             optimiser=Adam(1e-2), reg_lambda_upper=1e-9,
                             reg_lambda_lower=1e-9, batch_size=BATCH_SIZE)

compiled_net = compiler.compile(network)

# Metrics storage
train_accuracies = []
validation_accuracies = []

# Initialise early stopping
early_stopping = EarlyStopping(patience=20)

# Training
with compiled_net:
    start_time = perf_counter()
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    validation_callbacks = ["batch_progress_bar", Checkpoint(serialiser)]

    for epoch in range(11):
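        # With validation_split, train() returns training metrics, validation
        # metrics and two further callback-data values per call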
        train_metrics, validation_metrics, _, _ = compiled_net.train(
            {input: train_spikes}, {output: train_labels},
            num_epochs=1, shuffle=True,
            callbacks=callbacks,
            validation_callbacks=validation_callbacks,
            validation_split=0.1)

        # recording for each epoch
        train_accuracy = train_metrics[output].result # Adjust based on your metrics structure
        train_accuracies.append(train_accuracy)
        validation_accuracy = validation_metrics[output].result
        validation_accuracies.append(validation_accuracy)
        print(f"Epoch {epoch + 1}, Training accuracy: {train_accuracy}, Validation accuracy: {validation_accuracy}")
        # Check if we should stop training
        if early_stopping.should_stop(validation_accuracy):
            print(f"Early stopping triggered after {epoch + 1} epochs")
            break

    # Plotting the learning curve
    plt.figure(figsize=(10, 5))
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(validation_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Learning Curve')
    plt.legend()
    plt.ylim(0, 1)
    plt.show()
    end_time = perf_counter()
    print(f"Train Accuracy = {100 * train_metrics[output].result}%")
    print(f"Validation Accuracy = {100 * validation_metrics[output].result}%")
    print(f"Time = {end_time - start_time}s")

best_epoch = validation_accuracies.index(max(validation_accuracies)) + 1

# Preprocess the testing data
test_spikes = []
test_labels = []

for events, label in nmnist_test:
    processed_test_spikes = preprocess_tonic_spikes(events, nmnist_test.ordering,
                                                    nmnist_test.sensor_size)
    test_spikes.append(processed_test_spikes)
    test_labels.append(label)

# Evaluate
network.load((best_epoch - 1,), serialiser)  # Load weights from the best epoch (checkpoint keys are zero-indexed)
compiler = InferenceCompiler(evaluate_timesteps=max_example_timesteps,
                             batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)

with compiled_net:
    # Evaluate model on the test set
    start_time = perf_counter()
    metrics, _ = compiled_net.evaluate({input: test_spikes},
                                       {output: test_labels})
    # Get predicted outputs
    # predictions = compiled_net.get_readout(list(test_labels))
    # predicted_classes = np.argmax(predictions, axis=1)

    end_time = perf_counter()
    print(f"Accuracy = {100 * metrics[output].result}%")
    print(f"Time = {end_time - start_time}s")