mlech26l / ncps

PyTorch and TensorFlow implementation of NCP, LTC, and CfC wired neural models
https://www.nature.com/articles/s42256-020-00237-3
Apache License 2.0

Compatibility with GAP8 nntool #20

Closed: rallen10 closed this issue 2 years ago

rallen10 commented 2 years ago

I am working to port an NCP model to a micro-robotics platform, specifically the GAP8 processor onboard the Crazyflie AI-Deck. The GAP8 developers provide a toolset and workflow for porting Keras .h5 models into a format that can run on the GAP8 processor. However, I can't get it to work with a keras-ncp model.

I can successfully convert my trained model to .tflite using a slightly modified version of h5_to_tflite.py. However, when I attempt to open the .tflite model with nntool, I get the following error:

Traceback (most recent call last):
  File "/home/ross/miniconda3/envs/hello_ncp/lib/python3.7/site-packages/cmd2/cmd2.py", line 1661, in onecmd_plus_hooks
    stop = self.onecmd(statement, add_to_history=add_to_history)
  File "/home/ross/miniconda3/envs/hello_ncp/lib/python3.7/site-packages/cmd2/cmd2.py", line 2081, in onecmd
    stop = func(statement)
  File "/home/ross/miniconda3/envs/hello_ncp/lib/python3.7/site-packages/cmd2/decorators.py", line 223, in cmd_wrapper
    return func(cmd2_app, args)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/interpreter/commands/open.py", line 118, in do_open
    self.__open_graph(args)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/interpreter/commands/open.py", line 92, in __open_graph
    G = create_graph(graph_file, opts=opts)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/importer.py", line 52, in create_graph
    graph = importer.create_graph(filename, opts)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/tflite.py", line 103, in create_graph
    self._import_tflite_graph(G, model, opts)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/tflite.py", line 154, in _import_tflite_graph
    self._provisional_outputs, opts)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/tflite.py", line 245, in _import_nodes
    node, all_nodes=all_nodes, G=G, opts=opts, importer=self)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/handlers/handler.py", line 65, in handle
    return ver_handle(node, **kwargs)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/handlers/backend/fill.py", line 56, in version_1
    return cls._common(node, **kwargs)
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/handlers/backend/fill.py", line 40, in _common
    shape = list(cls._verify_constant(inputs[0]))
  File "/home/ross/Projects/AIIA/crazyflie/gap_sdk/tools/nntool/importer/tflite2/handlers/backend_handler.py", line 63, in _verify_constant
    raise ValueError("expected node %s to be constant input" % inp[0].name)
ValueError: expected node CONCATENATION_0_6 to be constant input
EXCEPTION of type 'ValueError' occurred with message: 'expected node CONCATENATION_0_6 to be constant input'

I've opened an issue with the GAP8 developers in the hope that the problem (and the fix) lies on their side (see here). I am also opening an issue here because it is not clear where the fundamental problem lies. My suspicion is that the incompatibility is in the LTCCell definition, but I don't know enough about either codebase to debug it on my own.

mlech26l commented 2 years ago

Does the conversion tool nntool support recurrent neural networks?

I am asking because the error seems to be caused by a CONCATENATION op. There is no such op in the LTCCell definition, but concatenation is commonly used in RNN implementations.

Were you able to convert an LSTM or a vanilla RNN with nntool?
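The traceback shows nntool's FILL handler calling `_verify_constant` on the FILL op's first input, which is the shape of the tensor to fill. It finds a CONCATENATION node there instead of a constant. One plausible reading, which is an assumption and not confirmed in this thread, is that the recurrent layer builds a zero initial state whose shape is assembled at runtime from the input's batch dimension. A converter that only accepts a constant shape would then reject it.

The toy model below is hypothetical and is not nntool's or Keras's code. The `Node` class, `is_constant`, `build_initial_state`, and `verify_fill` are illustrative names. The sketch only shows why such a check fails when the batch dimension is unknown at conversion time.

```python
"""Toy constant-folding check, loosely mimicking a converter's 'constant input' test.

Hypothetical illustration only: not nntool's implementation or Keras's graph.
"""
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    name: str
    op: str  # e.g. "INPUT", "CONST", "SHAPE", "STRIDED_SLICE", "CONCATENATION", "FILL"
    inputs: List["Node"] = field(default_factory=list)
    shape: Optional[List[Optional[int]]] = None  # None entries mean "unknown until runtime"


def is_constant(node: Node) -> bool:
    """Return True if the node's value can be computed at conversion time."""
    if node.op == "CONST":
        return True
    if node.op == "INPUT":
        return False  # runtime data is never constant
    if node.op == "SHAPE":
        # A tensor's shape folds to a constant only if every dimension is static.
        src = node.inputs[0]
        return src.shape is not None and all(d is not None for d in src.shape)
    # Any other op is foldable only if all of its inputs are foldable.
    return all(is_constant(i) for i in node.inputs)


def build_initial_state(batch: Optional[int], units: int) -> Node:
    """Zero initial state of shape [batch, units], with batch read from the input."""
    x = Node("input", "INPUT", shape=[batch, 10, 4])  # (batch, time, features)
    dims = Node("shape", "SHAPE", [x])
    batch_dim = Node("batch", "STRIDED_SLICE", [dims])  # stands in for shape(x)[0]
    units_const = Node(f"units_{units}", "CONST")
    fill_shape = Node("CONCATENATION", "CONCATENATION", [batch_dim, units_const])
    zero = Node("zero", "CONST")
    return Node("FILL", "FILL", [fill_shape, zero])


def verify_fill(fill: Node) -> None:
    """Reject a FILL whose shape input cannot be folded to a constant."""
    shape_input = fill.inputs[0]
    if not is_constant(shape_input):
        raise ValueError(f"expected node {shape_input.name} to be constant input")


if __name__ == "__main__":
    verify_fill(build_initial_state(batch=1, units=8))  # static batch: check passes
    try:
        verify_fill(build_initial_state(batch=None, units=8))  # dynamic batch
    except ValueError as err:
        print(err)  # expected node CONCATENATION to be constant input
```

In this toy model, fixing the batch dimension makes the shape foldable. Whether a fixed batch size changes anything for the real converter chain is untested here.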

rallen10 commented 2 years ago

Good question. Let me train a vanilla RNN and see whether it converts.

rallen10 commented 2 years ago

@mlech26l good catch. The same problem arises when I use a keras.layers.LSTM model, so the problem seems to be on the gap_sdk/nntool side rather than in keras-ncp. I will close this issue for now; thank you for the help.

For posterity: I tried both keras.layers.LSTM and keras.layers.SimpleRNN, and both produced the same error as above (with a slightly different node name, CONCATENATION_0_8).

mlech26l commented 2 years ago

Glad I could help