Closed: hockman1 closed this issue 3 years ago.
It should be possible: quantize and normalize the data, then encode it with spikes.
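A minimal sketch of that idea, assuming BindsNET's PoissonEncoder (the scale factor of 100 is an illustrative choice, not a tuned value):

import torch
from sklearn.datasets import load_iris
from bindsnet.encoding import PoissonEncoder

iris = load_iris()
X = torch.tensor(iris["data"], dtype=torch.float32)
# Min-max normalize each of the 4 features to [0, 1]
X = (X - X.min(0).values) / (X.max(0).values - X.min(0).values)
# Scale the normalized values to firing rates and encode one sample as Poisson spike trains
encoder = PoissonEncoder(time=100, dt=1.0)
spikes = encoder(X[0] * 100)  # shape: [time, 4]
print(spikes.shape, spikes.sum(0))  # total spikes per feature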
The discussion area is more suitable for this question, since this is not a bug or unexpected behavior.
Are there any examples of STDP training with the Iris dataset?
Not yet, but I think this is feasible:
The first thing I'm thinking of is to use the unsupervised Self Organizing Map (SOM) model: https://github.com/BindsNET/bindsnet/blob/61fda794534bf3c85307f8f71a6124ff5c7c4b52/examples/mnist/SOM_LM-SNNs.py
- the input layer size is 4
- encode the 4 decimal inputs into spikes using Poisson (just like in the SOM source)
- let the unsupervised STDP choose and arrange the most relevant neurons, given the input patterns
- try with an output layer of size 400; it should be fast and accurate
- let the supervised linear readout mechanism (from the SOM source) assign the output neurons (or neuron regions) to each of the 3 classes
- use about 100 time steps (ms) for each run
The current script has a lot of monitoring tools to help you understand what's going on.
The biggest task is switching from a 2D image (MNIST) to 4 input parameters.
You should get a decent result with a very small model!
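A minimal construction sketch under those assumptions (values other than n_input=4, n_neurons=400, and the 100 ms window are illustrative):

from bindsnet.models import IncreasingInhibitionNetwork
from bindsnet.encoding import PoissonEncoder

# The SOM example's network, with the 784-pixel input replaced by 4 Iris features
network = IncreasingInhibitionNetwork(
    n_input=4,
    n_neurons=400,
    inpt_shape=(1, 4),
)
encoder = PoissonEncoder(time=100, dt=1.0)  # about 100 time steps (ms) per run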
I tried to modify SOM_LM-SNNs.py, but I got 0 accuracy. Can you help me find the error? Here is my code:
import argparse
import os
from time import time as t
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
from bindsnet.analysis.plotting import (
plot_assignments,
plot_input,
plot_performance,
plot_spikes,
plot_voltages,
plot_weights,
)
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder, poisson
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
from bindsnet.models import IncreasingInhibitionNetwork
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_assignments, get_square_weights
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--n_neurons", type=int, default=400)
parser.add_argument("--n_epochs", type=int, default=1)
parser.add_argument("--n_test", type=int, default=10000)
parser.add_argument("--n_train", type=int, default=60000)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--theta_plus", type=float, default=0.05)
parser.add_argument("--time", type=int, default=100)
parser.add_argument("--dt", type=int, default=1.0)
parser.add_argument("--intensity", type=float, default=64)
parser.add_argument("--progress_interval", type=int, default=10)
parser.add_argument("--update_interval", type=int, default=250)
parser.add_argument("--update_inhibation_weights", type=int, default=500)
parser.add_argument("--plot_interval", type=int, default=250)
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=True, gpu=True)
args = parser.parse_args()
seed = args.seed
n_neurons = args.n_neurons
n_epochs = args.n_epochs
n_test = args.n_test
n_train = args.n_train
n_workers = args.n_workers
theta_plus = args.theta_plus
time = args.time
dt = args.dt
intensity = args.intensity
progress_interval = args.progress_interval
plot_interval = args.plot_interval
update_interval = args.update_interval
plot = args.plot
gpu = args.gpu
update_inhibation_weights = args.update_inhibation_weights
# Set up GPU use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if gpu and torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
device = "cpu"
if gpu:
gpu = False
torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)
# Determines number of workers to use
if n_workers == -1:
n_workers = 0 # torch.cuda.is_available() * 4 * torch.cuda.device_count()
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
# Build network.
network = IncreasingInhibitionNetwork(
n_input=4,
n_neurons=n_neurons,
start_inhib=10,
max_inhib=-40.0,
theta_plus=0.05,
tc_theta_decay=1e7,
inpt_shape=(1, 4),
nu=(1e-4, 1e-2),
)
network.to(device)
# Original MNIST loading (replaced by the Iris dataset below).
# dataset = MNIST(
# PoissonEncoder(time=time, dt=dt),
# None,
# root=os.path.join("..", "..", "data", "MNIST"),
# download=True,
# transform=transforms.Compose(
# [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
# ),
# )
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from bindsnet.encoding import Encoder, NullEncoder
from torch.utils.data import Dataset
iris = load_iris()
X = iris['data']*10
y = iris['target']
names = iris['target_names']
feature_names = iris['feature_names']
# Split the data set into training and testing
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=2)
class CustomDataset(Dataset):
def __init__(self,images,labels,image_encoder=None,label_encoder=None) -> None:
super().__init__()
# Allow the passthrough of None, but change to NullEncoder
if image_encoder is None:
image_encoder = NullEncoder()
if label_encoder is None:
label_encoder = NullEncoder()
self.images = images
self.labels = labels
self.image_encoder = image_encoder
self.label_encoder = label_encoder
def __getitem__(self, ind: int):
image, label = self.images[ind],self.labels[ind]
# print(image)
# print(label)
image = torch.tensor(image).unsqueeze(0).to(torch.float32)
label = torch.tensor(label)
# print(image.shape)
output = {
"image": image,
"label": label,
"encoded_image": self.image_encoder(image),
"encoded_label": self.label_encoder(label),
}
return output
def __len__(self) -> int:
return self.images.shape[0]
dataset = CustomDataset(images=X_train,labels=y_train,image_encoder=PoissonEncoder(time=time, dt=dt))
# Record spikes during the simulation.
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)
# Neuron assignments and spike proportions.
n_classes = 3
assignments = -torch.ones(n_neurons, device=device)
proportions = torch.zeros((n_neurons, n_classes), device=device)
rates = torch.zeros((n_neurons, n_classes), device=device)
# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}
# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(
network.layers["Y"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(som_voltage_monitor, name="som_voltage")
# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
)
network.add_monitor(spikes[layer], name="%s_spikes" % layer)
voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(
network.layers[layer], state_vars=["v"], time=int(time / dt), device=device
)
network.add_monitor(voltages[layer], name="%s_voltages" % layer)
inpt_ims, inpt_axes = None, None
spike_ims, spike_axes = None, None
weights_im = None
assigns_im = None
perf_ax = None
voltage_axes, voltage_ims = None, None
save_weights_fn = "plots/weights/weights.png"
save_performance_fn = "plots/performance/performance.png"
save_assaiments_fn = "plots/assaiments/assaiments.png"
directorys = ["plots", "plots/weights", "plots/performance", "plots/assaiments"]
for directory in directorys:
if not os.path.exists(directory):
os.makedirs(directory)
# off-diagonal mask used to increase the lateral inhibition
weights_mask = (1 - torch.diag(torch.ones(n_neurons))).to(device)
# Train the network.
print("\nBegin training.\n")
start = t()
for epoch in range(n_epochs):
labels = []
if epoch % progress_interval == 0:
print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
start = t()
# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu
)
pbar = tqdm(total=n_train)
for step, batch in enumerate(dataloader):
if step == n_train:
break
# Get next input sample.
inputs = {
"X": batch["encoded_image"].view(int(time / dt), 1, 4).to(device)
}
if step > 0:
if step % update_inhibation_weights == 0:
if step % (update_inhibation_weights * 10) == 0:
network.Y_to_Y.w -= weights_mask * 50
else:
# Inhibit the connection even more
# network.Y_to_Y.w -= weights_mask * network.Y_to_Y.w.abs()*0.2
network.Y_to_Y.w -= weights_mask * 0.5
if step % update_interval == 0:
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels, device=device)
# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)
# Compute network accuracy according to available classification strategies.
accuracy["all"].append(
100
* torch.sum(label_tensor.long() == all_activity_pred).item()
/ len(label_tensor)
)
accuracy["proportion"].append(
100
* torch.sum(label_tensor.long() == proportion_pred).item()
/ len(label_tensor)
)
tqdm.write(
"\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
% (
accuracy["all"][-1],
np.mean(accuracy["all"]),
np.max(accuracy["all"]),
)
)
tqdm.write(
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
" (best)\n"
% (
accuracy["proportion"][-1],
np.mean(accuracy["proportion"]),
np.max(accuracy["proportion"]),
)
)
# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
spikes=spike_record,
labels=label_tensor,
n_labels=n_classes,
rates=rates,
)
labels = []
labels.append(batch["label"])
temp_spikes = 0
factor = 1.2
for retry in range(5):
# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)
# Get spikes from the network
temp_spikes = spikes["Y"].get("s").squeeze()
if temp_spikes.sum().sum() < 2:
inputs["X"] *= (
poisson(
datum=factor * batch["image"].clamp(min=0),
dt=dt,
time=int(time / dt),
)
.to(device)
.view(int(time / dt), 1, 4)
)
factor *= factor
else:
break
# Get voltage recording.
exc_voltages = som_voltage_monitor.get("v")
# Add to spikes recording.
# spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
spike_record[step % update_interval].copy_(temp_spikes, non_blocking=True)
# Optionally plot various simulation information.
if plot and step % plot_interval == 0:
image = batch["image"].view(4)
inpt = inputs["X"].view(time, 4).sum(0).view(4)
input_exc_weights = network.connections[("X", "Y")].w
square_weights = get_square_weights(
input_exc_weights.view(4, n_neurons), n_sqrt, 2
)
square_assignments = get_square_assignments(assignments, n_sqrt)
spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
voltages = {"Y": exc_voltages}
# inpt_axes, inpt_ims = plot_input(
# image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims
# )
spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
[weights_im, save_weights_fn] = plot_weights(
square_weights, im=weights_im, save=save_weights_fn
)
assigns_im = plot_assignments(
square_assignments, im=assigns_im, save=save_assaiments_fn
)
perf_ax = plot_performance(accuracy, ax=perf_ax, save=save_performance_fn)
voltage_ims, voltage_axes = plot_voltages(
voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line"
)
#
plt.pause(1e-8)
network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Train progress: ")
pbar.update()
print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
# Original MNIST test-set loading (replaced by the Iris dataset below).
# test_dataset = MNIST(
# PoissonEncoder(time=time, dt=dt),
# None,
# root=os.path.join("..", "..", "data", "MNIST"),
# download=True,
# train=False,
# transform=transforms.Compose(
# [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
# ),
# )
test_dataset = CustomDataset(images=X_test,labels=y_test,image_encoder=PoissonEncoder(time=time, dt=dt))
# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}
# Record spikes during the simulation.
spike_record = torch.zeros(1, int(time / dt), n_neurons, device=device)
# Test the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()
pbar = tqdm(total=n_test)
for step, batch in enumerate(test_dataset):
if step >= n_test:
break
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 4)}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}
# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)
# Add to spikes recording.
spike_record[0] = spikes["Y"].get("s").squeeze()
# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"], device=device)
# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)
# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(
torch.sum(label_tensor.long() == proportion_pred).item()
)
network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Test progress: ")
pbar.update()
print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))
print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
X = iris['data']*10
is not enough to generate spikes, as shown by the neuron threshold plot.
I found that:
X = iris['data']*args.intensity
is better, especially when intensity is around 100.0.
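A quick way to sanity-check this (a sketch using BindsNET's poisson function; the tested intensities are illustrative):

import torch
from sklearn.datasets import load_iris
from bindsnet.encoding import poisson

iris = load_iris()
sample = torch.tensor(iris["data"][0], dtype=torch.float32)

# Count how many spikes the encoder produces per 100 ms sample at each scaling,
# to verify the input is strong enough to drive the network
for intensity in (10.0, 50.0, 100.0):
    spikes = poisson(datum=sample * intensity, time=100, dt=1.0)
    print("intensity=%.1f: %d spikes" % (intensity, int(spikes.sum())))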
Also, the Iris dataset is very small (120 training samples), so you have to adjust these:
parser.add_argument("--progress_interval", type=int, default=10)
parser.add_argument("--update_interval", type=int, default=250)
parser.add_argument("--update_inhibation_weights", type=int, default=500)
to account for the length of the dataset.
I'd suggest setting all of them to 120-1 = 119, for example, so that at every epoch you adjust the SOM mechanism (increasing the output neurons' lateral inhibition) and report the best average accuracy from the last epoch.
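Concretely, that would look like this (a sketch; 119 = 120 - 1 from the training split size):

# Align every update mechanism with one pass over the 120-sample training split
parser.add_argument("--progress_interval", type=int, default=119)
parser.add_argument("--update_interval", type=int, default=119)
parser.add_argument("--update_inhibation_weights", type=int, default=119)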
If you don't succeed with these ideas, try monitoring the 2D assignment map of the output neurons and check whether it is diverse enough. After some training, it should look like a smooth 2D map of the input labels (0, 1, or 2).
Also, don't hesitate to use many epochs, like 10 or more, because the dataset is so small.
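One way to check that diversity numerically (a sketch; assignments is the tensor maintained in the training script above, with -1 marking unassigned neurons):

import torch

# Count how many output neurons are assigned to each of the 3 classes
counts = torch.bincount((assignments + 1).long(), minlength=4)
print("unassigned:", int(counts[0]))
for label in range(3):
    print("class %d: %d neurons" % (label, int(counts[label + 1])))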
I tried the parameters you suggested, but I got very low accuracy, about 37%. Can you give more suggestions? I am new to SNNs.
import argparse
import os
from time import time as t
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
from bindsnet.analysis.plotting import (
plot_assignments,
plot_input,
plot_performance,
plot_spikes,
plot_voltages,
plot_weights,
)
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder, poisson
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
from bindsnet.models import IncreasingInhibitionNetwork
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_assignments, get_square_weights
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from bindsnet.encoding import Encoder, NullEncoder
from torch.utils.data import Dataset
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--n_neurons", type=int, default=13)
parser.add_argument("--n_epochs", type=int, default=20)
parser.add_argument("--n_test", type=int, default=30)
parser.add_argument("--n_train", type=int, default=120)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--theta_plus", type=float, default=0.05)
parser.add_argument("--time", type=int, default=100)
parser.add_argument("--dt", type=int, default=1.0)
parser.add_argument("--intensity", type=float, default=50)
parser.add_argument("--progress_interval", type=int, default=10)
parser.add_argument("--update_interval", type=int, default=10)
parser.add_argument("--update_inhibation_weights", type=int, default=10)
parser.add_argument("--plot_interval", type=int, default=500)
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=True, gpu=True)
args = parser.parse_args()
seed = args.seed
n_neurons = args.n_neurons
n_epochs = args.n_epochs
n_test = args.n_test
n_train = args.n_train
n_workers = args.n_workers
theta_plus = args.theta_plus
time = args.time
dt = args.dt
intensity = args.intensity
progress_interval = args.progress_interval
plot_interval = args.plot_interval
update_interval = args.update_interval
plot = args.plot
gpu = args.gpu
update_inhibation_weights = args.update_inhibation_weights
# Set up GPU use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if gpu and torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
device = "cpu"
if gpu:
gpu = False
torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)
# Determines number of workers to use
if n_workers == -1:
n_workers = 0 # torch.cuda.is_available() * 4 * torch.cuda.device_count()
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
# Build network.
network = IncreasingInhibitionNetwork(
n_input=4,
n_neurons=n_neurons,
start_inhib=10,
max_inhib=-40.0,
theta_plus=0.05,
tc_theta_decay=1e7,
inpt_shape=(1, 4),
nu=(1e-4, 1e-2),
)
network.to(device)
iris = load_iris()
X = iris['data']*args.intensity
y = iris['target']
names = iris['target_names']
feature_names = iris['feature_names']
# Split the data set into training and testing
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=2)
class CustomDataset(Dataset):
def __init__(self,images,labels,image_encoder=None,label_encoder=None) -> None:
super().__init__()
# Allow the passthrough of None, but change to NullEncoder
if image_encoder is None:
image_encoder = NullEncoder()
if label_encoder is None:
label_encoder = NullEncoder()
self.images = images
self.labels = labels
self.image_encoder = image_encoder
self.label_encoder = label_encoder
def __getitem__(self, ind: int):
image, label = self.images[ind],self.labels[ind]
image = torch.tensor(image).unsqueeze(0).to(torch.float32)
label = torch.tensor(label)
output = {
"image": image,
"label": label,
"encoded_image": self.image_encoder(image),
"encoded_label": self.label_encoder(label),
}
return output
def __len__(self) -> int:
return self.images.shape[0]
dataset = CustomDataset(images=X_train,labels=y_train,image_encoder=PoissonEncoder(time=time, dt=dt))
# Record spikes during the simulation.
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)
# Neuron assignments and spike proportions.
n_classes = 3
assignments = -torch.ones(n_neurons, device=device)
proportions = torch.zeros((n_neurons, n_classes), device=device)
rates = torch.zeros((n_neurons, n_classes), device=device)
# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}
# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(
network.layers["Y"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(som_voltage_monitor, name="som_voltage")
# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
)
network.add_monitor(spikes[layer], name="%s_spikes" % layer)
voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(
network.layers[layer], state_vars=["v"], time=int(time / dt), device=device
)
network.add_monitor(voltages[layer], name="%s_voltages" % layer)
inpt_ims, inpt_axes = None, None
spike_ims, spike_axes = None, None
weights_im = None
assigns_im = None
perf_ax = None
voltage_axes, voltage_ims = None, None
save_weights_fn = "plots/weights/weights.png"
save_performance_fn = "plots/performance/performance.png"
save_assaiments_fn = "plots/assaiments/assaiments.png"
directorys = ["plots", "plots/weights", "plots/performance", "plots/assaiments"]
for directory in directorys:
if not os.path.exists(directory):
os.makedirs(directory)
# off-diagonal mask used to increase the lateral inhibition
weights_mask = (1 - torch.diag(torch.ones(n_neurons))).to(device)
# Train the network.
print("\nBegin training.\n")
start = t()
for epoch in range(n_epochs):
labels = []
if epoch % progress_interval == 0:
print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
start = t()
# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu
)
pbar = tqdm(total=n_train)
for step, batch in enumerate(dataloader):
if step == n_train:
break
# Get next input sample.
inputs = {
"X": batch["encoded_image"].view(int(time / dt), 1, 4).to(device)
}
if step > 0:
if step % update_inhibation_weights == 0:
if step % (update_inhibation_weights * 10) == 0:
network.Y_to_Y.w -= weights_mask * 50
else:
# Inhibit the connection even more
# network.Y_to_Y.w -= weights_mask * network.Y_to_Y.w.abs()*0.2
network.Y_to_Y.w -= weights_mask * 0.5
if step % update_interval == 0:
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels, device=device)
# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)
# Compute network accuracy according to available classification strategies.
accuracy["all"].append(
100
* torch.sum(label_tensor.long() == all_activity_pred).item()
/ len(label_tensor)
)
accuracy["proportion"].append(
100
* torch.sum(label_tensor.long() == proportion_pred).item()
/ len(label_tensor)
)
tqdm.write(
"\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
% (
accuracy["all"][-1],
np.mean(accuracy["all"]),
np.max(accuracy["all"]),
)
)
tqdm.write(
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
" (best)\n"
% (
accuracy["proportion"][-1],
np.mean(accuracy["proportion"]),
np.max(accuracy["proportion"]),
)
)
# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
spikes=spike_record,
labels=label_tensor,
n_labels=n_classes,
rates=rates,
)
labels = []
labels.append(batch["label"])
temp_spikes = 0
factor = 1.2
for retry in range(5):
# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)
# Get spikes from the network
temp_spikes = spikes["Y"].get("s").squeeze()
if temp_spikes.sum().sum() < 2:
inputs["X"] *= (
poisson(
datum=factor * batch["image"].clamp(min=0),
dt=dt,
time=int(time / dt),
)
.to(device)
.view(int(time / dt), 1, 4)
)
factor *= factor
else:
break
# Get voltage recording.
exc_voltages = som_voltage_monitor.get("v")
# Add to spikes recording.
# spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
spike_record[step % update_interval].copy_(temp_spikes, non_blocking=True)
# Optionally plot various simulation information.
if plot and step % plot_interval == 0:
image = batch["image"].view(4)
inpt = inputs["X"].view(time, 4).sum(0).view(4)
input_exc_weights = network.connections[("X", "Y")].w
square_weights = get_square_weights(
input_exc_weights.view(4, n_neurons), n_sqrt, 2
)
square_assignments = get_square_assignments(assignments, n_sqrt)
spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
voltages = {"Y": exc_voltages}
inpt_axes, inpt_ims = plot_input(
image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims
)
spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
[weights_im, save_weights_fn] = plot_weights(
square_weights, im=weights_im, save=save_weights_fn
)
assigns_im = plot_assignments(
square_assignments, im=assigns_im, save=save_assaiments_fn
)
perf_ax = plot_performance(accuracy, ax=perf_ax, save=save_performance_fn)
voltage_ims, voltage_axes = plot_voltages(
voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line"
)
#
plt.pause(1e-8)
network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Train progress: ")
pbar.update()
print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
test_dataset = CustomDataset(images=X_test,labels=y_test,image_encoder=PoissonEncoder(time=time, dt=dt))
# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}
# Record spikes during the simulation.
spike_record = torch.zeros(1, int(time / dt), n_neurons, device=device)
# Test the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()
pbar = tqdm(total=n_test)
for step, batch in enumerate(test_dataset):
if step >= n_test:
break
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 4)}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}
# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)
# Add to spikes recording.
spike_record[0] = spikes["Y"].get("s").squeeze()
# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"], device=device)
# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)
# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(
torch.sum(label_tensor.long() == proportion_pred).item()
)
network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Test progress: ")
pbar.update()
print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))
print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
Can you post the assignment map of the output neurons and the spike raster plots of the network? That will give us an indication of the activity in the network.
Here are my results. Can you give more advice? Thanks!
From the weight matrix values (second image), it seems that the weights are saturated. To solve this, adjust the normalization value in the IncreasingInhibitionNetwork object from the default norm = 78.4 to something lower.
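For example (a sketch; norm=10.0 is an illustrative starting point, not a tuned value — the only guidance here is that it should be lower than the default 78.4):

# Lower the input-to-output weight normalization to avoid saturated weights
network = IncreasingInhibitionNetwork(
    n_input=4,
    n_neurons=n_neurons,
    inpt_shape=(1, 4),
    norm=10.0,  # default is 78.4
)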
Hi, is it possible to introduce simple STDP training for Iris dataset classification? Thanks!