mlech26l / ncps

PyTorch and TensorFlow implementation of NCP, LTC, and CfC wired neural models
https://www.nature.com/articles/s42256-020-00237-3
Apache License 2.0

Compatibility with TFLite fused LSTM operations #21

Closed rallen10 closed 2 years ago

rallen10 commented 2 years ago

This is a follow-on from issue https://github.com/mlech26l/keras-ncp/issues/20, but with broader implications and applications. A related issue is discussed here.

I am attempting to deploy an NCP model on specialized hardware. To do so, I need to convert a trained NCP model to .tflite. While a straightforward conversion appears to run without error, the resulting .tflite model does not contain a fused LSTM operator.

It might be too large of an ask, but is there any possibility of supporting the fused LSTM API? I am a bit out of my depth on how to even begin that process.

In order to demonstrate what I hope to achieve, and what is not working, I've included a zip file with the following contents:

hello_ncp.zip

The following runs training and TFLite conversion successfully, because it uses a vanilla LSTM instead of the NCP:

conda env create -f environment.yml
conda activate hello_ncp
python train_hello_ncp_tflite.py  --substitute-lstm    # uses LSTM instead of NCP, saves to hello_ncp.tflite
git clone git@github.com:GreenWaves-Technologies/gap_sdk.git    # toolset for deploying TFLite models on specialized hardware
source ./gap_sdk/configs/ai_deck.sh 
nntool
open hello_ncp.tflite

The following fails at the last step, because the NCP is not converted to a fused LSTM operation during TFLite conversion:

conda env create -f environment.yml
conda activate hello_ncp
python train_hello_ncp_tflite.py   # uses NCP, saves to hello_ncp.tflite
git clone git@github.com:GreenWaves-Technologies/gap_sdk.git    # toolset for deploying TFLite models on specialized hardware
source ./gap_sdk/configs/ai_deck.sh 
nntool
set debug true
open hello_ncp.tflite

Result (I can provide further traceback if helpful):

ValueError: no handler found for WHILE
EXCEPTION of type 'ValueError' occurred with message: 'no handler found for WHILE'
rallen10 commented 2 years ago

Based on my understanding of the TFLite RNN conversion process, there must be key differences between the LTCCell definition and this example of a custom RNN cell named LSTMCellSimple, which was successfully converted to a fused LSTM operation. Perhaps the missing piece for the LTCCell is the "conversion logic" (see the example "conversion logic" here).

However, I don't have enough insight into either the LTCCell or the LSTMCellSimple implementation to discern where the salient differences are, let alone the "conversion logic", which isn't present for the LTCCell.

mlech26l commented 2 years ago

My knowledge about tflite is rather limited.

After reading the RNN conversion guide, it seems that TFLite only supports LSTMs (and LSTM variants, via a custom conversion to the fused LSTM op).

As the NCP is quite different from an LSTM, no simple conversion is possible.
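To make the difference concrete: the fused LSTM op implements a fixed layout of input, forget, and output gates, while an LTC neuron integrates an ODE. The sketch below is a scalar version of the fused semi-implicit Euler step described in the LTC paper (Hasani et al.), dx/dt = -(1/tau + f) x + f A. It is an illustration under assumptions, not the ncps LTCCell code: ltc_step is a hypothetical helper, f is assumed to be a sigmoid of the input and the current state, and all parameter values are arbitrary.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def ltc_step(x: float, inp: float, dt: float = 0.1, unfolds: int = 6,
             tau: float = 1.0, A: float = 1.0,
             w_in: float = 1.0, w_rec: float = 0.5, b: float = 0.0) -> float:
    """Advance one LTC neuron by dt using `unfolds` fused semi-implicit Euler sub-steps.

    Hypothetical sketch: f = sigmoid(w_in * inp + w_rec * x + b) is an assumed nonlinearity.
    """
    h = dt / unfolds
    for _ in range(unfolds):
        f = _sigmoid(w_in * inp + w_rec * x + b)
        # Semi-implicit update of dx/dt = -(1/tau + f) * x + f * A:
        # the decay term is evaluated at the new x, so x appears in the denominator.
        x = (x + h * f * A) / (1.0 + h * (1.0 / tau + f))
    return x
```

The state-dependent denominator means the effective time constant changes with the input at every step, which has no counterpart in the gate structure the fused LSTM op expects.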

The code you provided seems to crash when parsing the tf.while loop used inside keras.layers.RNN (hence the "no handler found for WHILE" error). Therefore, one possible workaround is to define a single-step model, i.e., to avoid keras.layers.RNN entirely. You would then have to loop over the sequence yourself in the application code on the target device.
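On the application side, the loop amounts to feeding each returned state back in as the next call's state. Below is a minimal Python sketch of that pattern; run_sequence and toy_step are hypothetical names, toy_step merely stands in for the exported single-step NCP, and on the device the same loop would be written in whatever language the application uses.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# A single-step model: (input frame, previous state) -> (output, next state)
StepFn = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def run_sequence(step: StepFn, xs: Sequence[Vector], h0: Vector) -> Tuple[List[Vector], Vector]:
    """Run a single-step recurrent model over a whole sequence, carrying the state."""
    h = h0
    outputs: List[Vector] = []
    for x in xs:
        y, h = step(x, h)  # the new state is fed back on the next iteration
        outputs.append(y)
    return outputs, h


def toy_step(x: Vector, h: Vector) -> Tuple[Vector, Vector]:
    """Hypothetical stand-in for the exported NCP: a leaky integrator."""
    h_next = [0.5 * hi + xi for hi, xi in zip(h, x)]
    return h_next, h_next


outputs, final_state = run_sequence(toy_step, [[1.0], [1.0], [1.0]], [0.0])
```

Starting from a zero state matches what keras.layers.RNN uses by default when no initial state is passed, which matters if the looped outputs are to be compared with the original sequence model.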

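import tensorflow as tf
from tensorflow import keras

# ... Training code for the NCP (defines ncp_cell, ncp_model, N_INPUTS, args)

# Single-step model: one input frame and the previous hidden state go in,
# the output and the next hidden state come out. Calling the cell directly
# avoids keras.layers.RNN and therefore the tf.while loop it generates.
in_x = keras.Input(shape=(N_INPUTS,))
in_h = keras.Input(shape=(ncp_cell.state_size,))
ncp_out, ncp_state = ncp_cell(in_x, [in_h])
ncp_standalone = keras.Model(inputs=[in_x, in_h], outputs=[ncp_out, ncp_state[0]])

@tf.function
def run_model(x, h):
    return ncp_standalone([x, h])

# Trace with a fixed batch size so the converter sees static input shapes
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([args.batch_size, N_INPUTS], ncp_model.inputs[0].dtype),
    tf.TensorSpec(
        [args.batch_size, ncp_cell.state_size], ncp_model.inputs[0].dtype
    ),
)

# ... export to tflite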
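To exercise the exported single-step model on the host, one option is to wrap tf.lite.Interpreter so that each call performs one time step. This is a minimal sketch under assumptions: make_tflite_step is a hypothetical helper, and the interpreter is assumed to have been created with tf.lite.Interpreter(model_path=...) with allocate_tensors() already called. The four tensor indices have to be looked up from get_input_details() and get_output_details() (for example by "name" or "shape"), since their order may not follow the Keras model's inputs and outputs.

```python
def make_tflite_step(interpreter, x_index, h_index, out_index, state_index):
    """Wrap a single-step TFLite model as step(x, h) -> (output, next_state).

    `interpreter` is assumed to be an allocated tf.lite.Interpreter; the indices
    are tensor indices taken from its input/output details.
    """
    def step(x, h):
        interpreter.set_tensor(x_index, x)  # x: array of shape [batch_size, N_INPUTS]
        interpreter.set_tensor(h_index, h)  # h: array of shape [batch_size, state_size]
        interpreter.invoke()
        # get_tensor returns a copy, so results stay valid after the next invoke()
        return interpreter.get_tensor(out_index), interpreter.get_tensor(state_index)

    return step
```

The returned function has the same (output, next state) contract as ncp_standalone, so it can be driven by an ordinary loop over the sequence, feeding each returned state back in as h.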
rallen10 commented 2 years ago

Thank you for your response! I've been testing and debugging this solution all day. I've successfully implemented the single-step model, converted it to TFLite, and verified that it produces the same outputs during inference as the n-step model. However, I am still running into an error (albeit a new one) when I run nntool. I believe the error lies in the nntool implementation, so I won't ask about it here, but I'd like to leave this issue open for the time being, if that is alright.
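A check that the single-step model reproduces the n-step model can be done by comparing the two output sequences element-wise with a tolerance. The helpers below are a hypothetical sketch, and the atol=1e-5 default is an arbitrary assumption; a tolerance is used because the converted model and the Keras model may not agree bit-for-bit in float32.

```python
from typing import Sequence


def max_abs_diff(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> float:
    """Largest element-wise |a - b| between two sequences of output vectors."""
    if len(a) != len(b):
        raise ValueError("sequences have different lengths")
    worst = 0.0
    for row_a, row_b in zip(a, b):
        if len(row_a) != len(row_b):
            raise ValueError("output vectors have different sizes")
        for va, vb in zip(row_a, row_b):
            worst = max(worst, abs(va - vb))
    return worst


def outputs_match(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]],
                  atol: float = 1e-5) -> bool:
    """True if every element of a is within atol of the matching element of b."""
    return max_abs_diff(a, b) <= atol
```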

mlech26l commented 2 years ago

Ok, sure

rallen10 commented 2 years ago

The related nntool problem was resolved here.

Closing...