NeuromorphicProcessorProject / snn_toolbox

Toolbox for converting analog to spiking neural networks (ANN to SNN), and running them in a spiking neuron simulator.
MIT License

Error when converting Pytorch model with an nn.MaxPool2d layer #103

Closed. JakobVokac closed this issue 2 years ago.

JakobVokac commented 3 years ago

I tried converting the CORnet-S model (https://github.com/dicarlolab/CORnet/blob/master/cornet/cornet_s.py) into an SNN by following and modifying the MNIST PyTorch example (https://github.com/NeuromorphicProcessorProject/snn_toolbox/blob/master/examples/mnist_pytorch_INI.py). The toolbox was able to successfully port the ONNX model to Keras, build the parsed model, compile it, and evaluate it on the dataset. Midway through building the spiking model, I get the following error when it reaches the MaxPooling2D layer:

Building spiking model...
Building layer: 00ZeroPadding2D_10x134x134
Building layer: 01Conv2D_64x64x64
Building layer: 02ZeroPadding2D_64x66x66
Building layer: 03MaxPooling2D_64x32x32
Traceback (most recent call last):
  File "snn_test.py", line 114, in <module>
    main(config_filepath)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\bin\run.py", line 31, in main
    run_pipeline(config)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\bin\utils.py", line 127, in run_pipeline
    spiking_model.build(parsed_model, testset)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\simulation\utils.py", line 438, in build
    self.setup_layers(batch_shape)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\simulation\utils.py", line 780, in setup_layers
    self.add_layer(layer)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\simulation\target_simulators\INI_temporal_mean_rate_target_sim.py", line 89, in add_layer
    self._spiking_layers[layer.name] = spike_layer(inbound)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 969, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1107, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 840, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 880, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\eager\function.py", line 3971, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\simulation\backends\inisim\temporal_mean_rate_tensorflow.py:506 decorator  *
    self.impulse = call(self, x)
C:\Anaconda\envs\v2e-env\lib\site-packages\snntoolbox\simulation\backends\inisim\temporal_mean_rate_tensorflow.py:833 call  *
    _, max_idxs = tf.nn.max_pool_with_argmax(
C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\util\dispatch.py:206 wrapper  **
    return target(*args, **kwargs)
C:\Anaconda\envs\v2e-env\lib\site-packages\tensorflow\python\ops\nn_ops.py:4871 max_pool_with_argmax_v2
    raise ValueError("Data formats other than 'NHWC' are not yet supported")

ValueError: Data formats other than 'NHWC' are not yet supported

The config file is:

[paths]
path_wd = C:\Users\Jakob\Projects\v2e_vonenet\snn\runs\1628587342.5728104
dataset_path = C:\Users\Jakob\Projects\v2e_vonenet\snn\runs\1628587342.5728104
filename_ann = pytorch_cnn

[tools]
evaluate_ann = True
normalize = True

[simulation]
simulator = INI
duration = 50
num_to_test = 1000
batch_size = 1
keras_backend = tensorflow

[input]
model_lib = pytorch

[output]
plot_vars = {'spiketrains', 'v_mem', 'correlation', 'error_t', 'activations', 'spikerates'}
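For reference, a minimal driver script in the spirit of the MNIST PyTorch example (and of the snn_test.py appearing in the traceback) could write this config and start the pipeline as sketched below. The dataset filenames (x_test.npz, y_test.npz, and x_norm.npz for normalization) and the saved-model name are assumptions taken from that example, not a definitive description of my setup:

import configparser
import os

from snntoolbox.bin.run import main

# Working directory: assumed to already contain the saved PyTorch model
# (name given by filename_ann, format as in the MNIST example) and the
# test set as x_test.npz / y_test.npz (plus x_norm.npz for normalization).
path_wd = r'C:\Users\Jakob\Projects\v2e_vonenet\snn\runs\1628587342.5728104'

config = configparser.ConfigParser()
config['paths'] = {'path_wd': path_wd,
                   'dataset_path': path_wd,
                   'filename_ann': 'pytorch_cnn'}
config['tools'] = {'evaluate_ann': True, 'normalize': True}
config['simulation'] = {'simulator': 'INI',
                        'duration': 50,
                        'num_to_test': 1000,
                        'batch_size': 1,
                        'keras_backend': 'tensorflow'}
config['input'] = {'model_lib': 'pytorch'}

config_filepath = os.path.join(path_wd, 'config')
with open(config_filepath, 'w') as f:
    config.write(f)

# Runs parsing, conversion to the spiking model and the INI simulation.
main(config_filepath)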

The versions of the packages are:

rbodo commented 3 years ago

Hi Jakob,

The spiking max-pooling implementation uses the TensorFlow function max_pool_with_argmax, which currently only supports the "channels last" (NHWC) data format. Can you set up your input model with channels last?
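To make the restriction concrete, here is a minimal sketch of the underlying TensorFlow call (the shapes are arbitrary and not taken from CORnet-S):

import tensorflow as tf

# Channels-last (NHWC) input: this is the only format the argmax variant
# of max pooling accepts.
x_nhwc = tf.random.uniform((1, 66, 66, 64))  # batch, height, width, channels

out, argmax = tf.nn.max_pool_with_argmax(
    x_nhwc, ksize=3, strides=2, padding='VALID', data_format='NHWC')

# Passing data_format='NCHW' instead raises the error from the traceback:
# ValueError: Data formats other than 'NHWC' are not yet supported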

JakobVokac commented 3 years ago

Yes. I rearranged the dataset and set the ordering to channels_last and that fixed it, thank you! Is there any specific reason the example says the parser requires channels_first? Could it lead to any other errors?
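For anyone hitting the same issue, the rearrangement amounts to something like the sketch below; the filenames and the 'arr_0' key are illustrative, so adapt them to however your test set is stored:

import numpy as np

# Transpose the saved test set from channels-first (N, C, H, W) to
# channels-last (N, H, W, C) before handing it to the toolbox.
x_test = np.load('x_test_channels_first.npz')['arr_0']  # (N, C, H, W)
x_test = np.transpose(x_test, (0, 2, 3, 1))             # (N, H, W, C)
np.savez_compressed('x_test.npz', x_test)

# The model itself also has to be defined (or re-exported) with
# channels-last ordering so that the parsed Keras model and the data agree.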

rbodo commented 3 years ago

The external onnx2keras tool that is used under the hood for parsing your model only supported channels_first when I first implemented this frontend. Perhaps this restriction has been lifted by now. If you don't get an error during parsing and the accuracy of the parsed model is the same as that of the original model, then you should be fine.
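One way to run that sanity check by hand, assuming the toolbox wrote the parsed model into path_wd (the filename pytorch_cnn_parsed.h5 below is an assumption, adjust it to whatever file the toolbox produced):

import numpy as np
from tensorflow import keras

# Channels-last test set and one-hot labels, as saved for the toolbox.
x_test = np.load('x_test.npz')['arr_0']
y_test = np.load('y_test.npz')['arr_0']

parsed = keras.models.load_model('pytorch_cnn_parsed.h5')
parsed.compile('sgd', 'categorical_crossentropy', ['accuracy'])
_, acc = parsed.evaluate(x_test, y_test, batch_size=1, verbose=0)
print(f'Parsed model accuracy: {acc:.4f}')  # should match the original ANN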