tensorflow / tflite-micro

Infrastructure to enable deployment of ML models to low-power resource-constrained embedded targets (including microcontrollers and digital signal processors).
Apache License 2.0

RNN tflite model doesn't work when deployed on a microcontroller #1570

Closed jkaldal closed 1 year ago

jkaldal commented 1 year ago

System information

Describe the feature and the current behavior/state. An RNN tflite model does not work when deployed on a microcontroller unless it is unrolled, which increases the size of the model tenfold.

I want to compare different types of RNN tflite-micro models. I have created a custom RNN cell that I want to compare with the LSTM cell, GRU cell, and SimpleRNN cell.

The following code shows how the network is created and converted to a tflite model:

import tensorflow as tf

# create the rnn cell using one of these RNN cells
units = <value>
#RNNcell = tf.keras.layers.SimpleRNNCell(units)
#RNNcell = tf.keras.layers.GRUCell(units)
RNNcell = tf.keras.layers.LSTMCell(units)

# create the model using the cell in an RNN layer
timesteps = <value>
n_features = <value>
dropout = <value>
dense_units = <value>
outputs = <value>

model = tf.keras.Sequential([
    tf.keras.layers.RNN(RNNcell, input_shape=(timesteps, n_features)),
    tf.keras.layers.Dense(dense_units, activation='relu'),
    tf.keras.layers.Dense(outputs, activation='softmax', name="output")
])

# then create model for inference with a fixed batch size
model_inference = tf.keras.Sequential([
    tf.keras.layers.RNN(RNNcell, batch_size=1, input_shape=(timesteps, n_features)),
    tf.keras.layers.Dense(dense_units, activation='relu'),
    tf.keras.layers.Dense(outputs, activation='softmax', name="output")
])

# train the model
EPOCHS = <value>
BATCH_SIZE = <value>
LEARNING_RATE = <value>

opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(
    train_X, train_y, 
    epochs=EPOCHS, 
    validation_data=(validation_X, validation_y), 
    shuffle=False, 
    batch_size=BATCH_SIZE)

# move weights to model with fixed batch_size
# (copies the weights of the trained model defined above)
model_inference.set_weights(model.get_weights())
model_inference.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

# save to tflite format
converter = tf.lite.TFLiteConverter.from_keras_model(model_inference)
tfmodel = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tfmodel)

## EXTRA
# test operations with tflite interpreter in python
input_data = <value>

interpreter = tf.lite.Interpreter(model_content=tfmodel)
interpreter.allocate_tensors()

_input_index = interpreter.get_input_details()[0]['index']
interpreter.set_tensor(_input_index, input_data)
interpreter.invoke()

The tflite model works in the Python interpreter but does not run on a microcontroller (nRF5340) unless the RNN layer is unrolled, which increases the size of the model significantly.

This is how the model looks in Netron (image below). The error from the interpreter on the microcontroller indicates that the ADD op inside the WHILE operator does not support INT32:

Type INT32 (2) not supported
Node ADD (number 0) failed to invoke with status 1
Node WHILE (number 1) failed to invoke with status 1
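
One way to see which ADD ops the log may be referring to is to scan a tf.lite.experimental.Analyzer listing of the model for ADD ops whose inputs are all INT32. The sketch below does this in plain Python on a short excerpt of the listing posted further down in this thread. The helper name int32_ops and the two regular expressions are illustrative assumptions about the text layout of that listing, not part of any TensorFlow API.

```python
import re

# Excerpt of an Analyzer listing for the loop body (op and tensor lines
# copied from the listing posted later in this thread).
LISTING = """\
Subgraph#2 sequential_1/lstm/while_body(T#2_0, T#2_1, T#2_2, T#2_3, T#2_4, T#2_5) -> [T#2_13, T#2_11, T#2_31, T#2_30, T#2_28, T#2_5]
  Op#0 ADD(T#2_1, T#2_6) -> [T#2_11]
  Op#2 ADD(T#2_0, T#2_6) -> [T#2_13]
  Op#5 ADD(T#2_15, T#2_12) -> [T#2_16]
  T#2_0(arg0) shape:[], type:INT32
  T#2_1(arg1) shape:[], type:INT32
  T#2_6(sequential_1/lstm/while/add/y) shape:[], type:INT32 RO 4 bytes
  T#2_12(sequential_1/lstm/while/lstm_cell/MatMul_11) shape:[1, 400], type:FLOAT32
  T#2_15(sequential_1/lstm/while/lstm_cell/MatMul2) shape:[1, 400], type:FLOAT32
"""

# Hypothetical patterns for "Op#N NAME(inputs) -> [outputs]" and
# "T#id(name) shape:[...], type:TYPE" lines.
OP_RE = re.compile(r"^\s*Op#(\d+) (\w+)\((.*?)\) -> \[(.*?)\]")
TENSOR_RE = re.compile(r"^\s*(T#[\d_]+)\(.*\) shape:\[.*?\], type:(\w+)")


def int32_ops(listing, op_name="ADD"):
    """Return (op_index, inputs) for each `op_name` op fed only by INT32 tensors."""
    types = {}  # tensor id -> type string
    ops = []    # (op index, op name, input tensor ids)
    for line in listing.splitlines():
        m = TENSOR_RE.match(line)
        if m:
            types[m.group(1)] = m.group(2)
            continue
        m = OP_RE.match(line)
        if m:
            inputs = [t.strip() for t in m.group(3).split(",")]
            ops.append((int(m.group(1)), m.group(2), inputs))
    return [(idx, ins) for idx, name, ins in ops
            if name == op_name and all(types.get(t) == "INT32" for t in ins)]


print(int32_ops(LISTING))
```

On this excerpt the two matches are Op#0 and Op#2 of the loop body, which add the INT32 constant T#2_6 to the two INT32 loop arguments. Op#0 lines up with "Node ADD (number 0)" in the log, although the log alone does not confirm which subgraph the node belongs to.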

[Image: Netron visualization of the network]

Will this change the current api? How? don't know

Who will benefit with this feature? Everybody who needs and thought RNNs were supported in Tensorflow Lite Micro.

Any Other info. Similar to issue #907, but in that issue the author unrolls the network.

Also, a tflite model created with tf.keras.layers.LSTM contains the operator unidirectionalLSTM, but when the model is created with a tf.keras.layers.LSTMCell in a tf.keras.layers.RNN layer, it has a reshape and a while operator instead.

Is it possible to use the unidirectionalLSTM operator for the other cells as well somehow?

jkaldal commented 1 year ago

When I run the analyser on the tflite model:

tf.lite.experimental.Analyzer.analyze(model_path='model.tflite')

this is what I get

=== model.tflite ===

Your TFLite model has '3' subgraph(s). In the subgraph description below,
T# represents the Tensor numbers. For example, in Subgraph#0, the RESHAPE op takes
tensor #0 and tensor #11 as input and produces tensor #13 as output.

Subgraph#0 main(T#0) -> [T#23]
  Op#0 RESHAPE(T#0, T#11) -> [T#13]
  Op#1 WHILE(T#3, T#3, T#1, T#4, T#4, T#13, Cond: Subgraph#1, Body: Subgraph#2) -> [T#14, T#15, T#16, T#17, T#18, T#19]
  Op#2 STRIDED_SLICE(T#16, T#8, T#9, T#10) -> [T#20]
  Op#3 FULLY_CONNECTED(T#20, T#12, T#5) -> [T#21]
  Op#4 FULLY_CONNECTED(T#21, T#7, T#6) -> [T#22]
  Op#5 SOFTMAX(T#22) -> [T#23]

Tensors of Subgraph#0
  T#0(serving_default_lstm_input:0) shape:[1, 128, 6], type:FLOAT32
  T#1(sequential_1/lstm/TensorArrayV2_1) shape:[1, 1, 100], type:FLOAT32 RO 400 bytes
  T#2(sequential_1/lstm/strided_slice_1) shape:[], type:INT32 RO 4 bytes
  T#3(sequential_1/lstm/time) shape:[], type:INT32 RO 4 bytes
  T#4(sequential_1/lstm/zeros) shape:[1, 100], type:FLOAT32 RO 400 bytes
  T#5(sequential_1/dense_1/BiasAdd/ReadVariableOp) shape:[40], type:FLOAT32 RO 160 bytes
  T#6(sequential_1/output/BiasAdd/ReadVariableOp) shape:[6], type:FLOAT32 RO 24 bytes
  T#7(sequential_1/output/MatMul) shape:[6, 40], type:FLOAT32 RO 960 bytes
  T#8(sequential_1/lstm/strided_slice_3) shape:[3], type:INT32 RO 12 bytes
  T#9(sequential_1/lstm/strided_slice_31) shape:[3], type:INT32 RO 12 bytes
  T#10(sequential_1/lstm/strided_slice_32) shape:[3], type:INT32 RO 12 bytes
  T#11(sequential_1/lstm/transpose) shape:[3], type:INT32 RO 12 bytes
  T#12(sequential_1/dense_1/MatMul) shape:[40, 100], type:INT8 RO 4000 bytes
  T#13(sequential_1/lstm/transpose1) shape:[128, 1, 6], type:FLOAT32
  T#14(sequential_1/lstm/while) shape:[], type:INT32
  T#15(sequential_1/lstm/while1) shape:[], type:INT32
  T#16(sequential_1/lstm/while2) shape:[1, 1, 100], type:FLOAT32
  T#17(sequential_1/lstm/while3) shape:[1, 100], type:FLOAT32
  T#18(sequential_1/lstm/while4) shape:[1, 100], type:FLOAT32
  T#19(sequential_1/lstm/while5) shape:[128, 1, 6], type:FLOAT32
  T#20(sequential_1/lstm/strided_slice_33) shape:[1, 100], type:FLOAT32
  T#21(sequential_1/dense_1/MatMul;sequential_1/dense_1/Relu;sequential_1/dense_1/BiasAdd) shape:[1, 40], type:FLOAT32
  T#22(sequential_1/output/MatMul;sequential_1/output/BiasAdd) shape:[1, 6], type:FLOAT32
  T#23(StatefulPartitionedCall:0) shape:[1, 6], type:FLOAT32

Subgraph#1 sequential_1/lstm/while_cond(T#1_0, T#1_1, T#1_2, T#1_3, T#1_4, T#1_5) -> [T#1_7]
  Op#0 LESS(T#1_1, T#1_6) -> [T#1_7]

Tensors of Subgraph#1
  T#1_0(arg0) shape:[], type:INT32
  T#1_1(arg1) shape:[], type:INT32
  T#1_2(arg2) shape:[1, 1, 100], type:FLOAT32
  T#1_3(arg3) shape:[1, 100], type:FLOAT32
  T#1_4(arg4) shape:[1, 100], type:FLOAT32
  T#1_5(arg5) shape:[128, 1, 6], type:FLOAT32
  T#1_6(sequential_1/lstm/strided_slice_11) shape:[], type:INT32 RO 4 bytes
  T#1_7(sequential_1/lstm/while/Less) shape:[], type:BOOL

Subgraph#2 sequential_1/lstm/while_body(T#2_0, T#2_1, T#2_2, T#2_3, T#2_4, T#2_5) -> [T#2_13, T#2_11, T#2_31, T#2_30, T#2_28, T#2_5]
  Op#0 ADD(T#2_1, T#2_6) -> [T#2_11]
  Op#1 FULLY_CONNECTED(T#2_3, T#2_10, T#-1) -> [T#2_12]
  Op#2 ADD(T#2_0, T#2_6) -> [T#2_13]
  Op#3 GATHER(T#2_5, T#2_1) -> [T#2_14]
  Op#4 FULLY_CONNECTED(T#2_14, T#2_9, T#-1) -> [T#2_15]
  Op#5 ADD(T#2_15, T#2_12) -> [T#2_16]
  Op#6 ADD(T#2_16, T#2_7) -> [T#2_17]
  Op#7 SPLIT(T#2_6, T#2_17) -> [T#2_18, T#2_19, T#2_20, T#2_21]
  Op#8 LOGISTIC(T#2_18) -> [T#2_22]
  Op#9 LOGISTIC(T#2_19) -> [T#2_23]
  Op#10 MUL(T#2_23, T#2_4) -> [T#2_24]
  Op#11 LOGISTIC(T#2_21) -> [T#2_25]
  Op#12 TANH(T#2_20) -> [T#2_26]
  Op#13 MUL(T#2_22, T#2_26) -> [T#2_27]
  Op#14 ADD(T#2_24, T#2_27) -> [T#2_28]
  Op#15 TANH(T#2_28) -> [T#2_29]
  Op#16 MUL(T#2_25, T#2_29) -> [T#2_30]
  Op#17 RESHAPE(T#2_30, T#2_8) -> [T#2_31]

Tensors of Subgraph#2
  T#2_0(arg0) shape:[], type:INT32
  T#2_1(arg1) shape:[], type:INT32
  T#2_2(arg2) shape:[1, 1, 100], type:FLOAT32
  T#2_3(arg3) shape:[1, 100], type:FLOAT32
  T#2_4(arg4) shape:[1, 100], type:FLOAT32
  T#2_5(arg5) shape:[128, 1, 6], type:FLOAT32
  T#2_6(sequential_1/lstm/while/add/y) shape:[], type:INT32 RO 4 bytes
  T#2_7(sequential_1/lstm/while/lstm_cell/BiasAdd/ReadVariableOp) shape:[400], type:FLOAT32 RO 1600 bytes
  T#2_8(sequential_1/lstm/while/TensorArrayV2Write/TensorListSetItem) shape:[3], type:INT32 RO 12 bytes
  T#2_9(sequential_1/lstm/while/lstm_cell/MatMul1) shape:[400, 6], type:INT8 RO 2400 bytes
  T#2_10(sequential_1/lstm/while/lstm_cell/MatMul_1) shape:[400, 100], type:INT8 RO 40000 bytes
  T#2_11(sequential_1/lstm/while/add) shape:[], type:INT32
  T#2_12(sequential_1/lstm/while/lstm_cell/MatMul_11) shape:[1, 400], type:FLOAT32
  T#2_13(sequential_1/lstm/while/add_1) shape:[], type:INT32
  T#2_14(sequential_1/lstm/while/TensorArrayV2Read/TensorListGetItem;sequential_1/lstm/time) shape:[1, 6], type:FLOAT32
  T#2_15(sequential_1/lstm/while/lstm_cell/MatMul2) shape:[1, 400], type:FLOAT32
  T#2_16(sequential_1/lstm/while/lstm_cell/add) shape:[1, 400], type:FLOAT32
  T#2_17(sequential_1/lstm/while/lstm_cell/BiasAdd) shape:[1, 400], type:FLOAT32
  T#2_18(sequential_1/lstm/while/lstm_cell/split) shape:[1, 100], type:FLOAT32
  T#2_19(sequential_1/lstm/while/lstm_cell/split1) shape:[1, 100], type:FLOAT32
  T#2_20(sequential_1/lstm/while/lstm_cell/split2) shape:[1, 100], type:FLOAT32
  T#2_21(sequential_1/lstm/while/lstm_cell/split3) shape:[1, 100], type:FLOAT32
  T#2_22(sequential_1/lstm/while/lstm_cell/Sigmoid) shape:[1, 100], type:FLOAT32
  T#2_23(sequential_1/lstm/while/lstm_cell/Sigmoid_1) shape:[1, 100], type:FLOAT32
  T#2_24(sequential_1/lstm/while/lstm_cell/mul) shape:[1, 100], type:FLOAT32
  T#2_25(sequential_1/lstm/while/lstm_cell/Sigmoid_2) shape:[1, 100], type:FLOAT32
  T#2_26(sequential_1/lstm/while/lstm_cell/Tanh) shape:[1, 100], type:FLOAT32
  T#2_27(sequential_1/lstm/while/lstm_cell/mul_1) shape:[1, 100], type:FLOAT32
  T#2_28(sequential_1/lstm/while/lstm_cell/add_1) shape:[1, 100], type:FLOAT32
  T#2_29(sequential_1/lstm/while/lstm_cell/Tanh_1) shape:[1, 100], type:FLOAT32
  T#2_30(sequential_1/lstm/while/lstm_cell/mul_2) shape:[1, 100], type:FLOAT32
  T#2_31(sequential_1/lstm/while/TensorArrayV2Write/TensorListSetItem1) shape:[1, 1, 100], type:FLOAT32

---------------------------------------------------------------
Your TFLite model has '1' signature_def(s).

Signature#0 key: 'serving_default'
- Subgraph: Subgraph#0
- Inputs: 
    'lstm_input' : T#0
- Outputs: 
    'output' : T#23

---------------------------------------------------------------
              Model size:      58128 bytes
    Non-data buffer size:       8004 bytes (13.77 %)
  Total data buffer size:      50124 bytes (86.23 %)
          - Subgraph#0  :       6000 bytes (10.32 %)
          - Subgraph#1  :          4 bytes (00.01 %)
          - Subgraph#2  :      44016 bytes (75.72 %)
    (Zero value buffers):        804 bytes (01.38 %)

* Buffers of TFLite model are mostly used for constant tensors.
  And zero value buffers are buffers filled with zeros.
  Non-data buffers area are used to store operators, subgraphs and etc.
  You can find more details from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs
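
The per-subgraph buffer sizes in the summary can be cross-checked by summing the sizes of the read-only ("RO") tensors in the listing above. The sketch below does that arithmetic in plain Python. The byte counts are copied from the listing, and the variable names are only illustrative.

```python
# Sizes in bytes of every read-only ("RO") tensor, copied from the listing.
RO_BYTES = {
    "Subgraph#0": [400, 4, 4, 400, 160, 24, 960, 12, 12, 12, 12, 4000],
    "Subgraph#1": [4],
    "Subgraph#2": [4, 1600, 12, 2400, 40000],
}

# Figures reported in the summary at the end of the Analyzer output.
REPORTED = {"Subgraph#0": 6000, "Subgraph#1": 4, "Subgraph#2": 44016}
MODEL_SIZE = 58128
TOTAL_DATA = 50124


def subgraph_totals(ro_bytes):
    """Sum the read-only tensor sizes of each subgraph."""
    return {name: sum(sizes) for name, sizes in ro_bytes.items()}


totals = subgraph_totals(RO_BYTES)
for name, total in totals.items():
    print(f"{name}: {total} bytes (reported {REPORTED[name]})")

# LSTM kernel [400, 6], recurrent kernel [400, 100] (both INT8) and bias [400] (FLOAT32).
lstm_weight_bytes = 2400 + 40000 + 1600
print(f"LSTM weights and bias: {lstm_weight_bytes / MODEL_SIZE:.1%} of the file")

# Bytes of the reported total data buffer size not covered by the three subgraph sums.
unaccounted = TOTAL_DATA - sum(totals.values())
print(f"Not attributed to a subgraph: {unaccounted} bytes")
```

The sums match the reported per-subgraph figures, and the LSTM weights in the loop body account for most of the file. The listing does not say what the remaining bytes of the total data buffer size are used for.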

github-actions[bot] commented 1 year ago

"This issue is being marked as stale due to inactivity. Remove label or comment to prevent closure in 5 days."

github-actions[bot] commented 1 year ago

"This issue is being closed because it has been marked as stale for 5 days with no further activity."

TayyabaZainab0807 commented 1 year ago

Did you find a solution for this?

jkaldal commented 1 year ago

You have to unroll the network.



jonnor commented 1 year ago

How does one unroll the network?

jkaldal commented 1 year ago

It is the unroll argument of tf.keras.layers.RNN (set unroll=True).

Not a good solution for large networks.
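
Unrolling is enabled by passing unroll=True to tf.keras.layers.RNN. The toy sketch below is plain Python, not TFLite code, and only illustrates the trade-off under simplifying assumptions. In the loop form, the step body exists once and an integer counter drives it, and that counter increment is the kind of INT32 ADD named in the error. In the unrolled form, the step is repeated once per timestep and no counter is needed. The scalar cell and all function names are hypothetical.

```python
import math


def rnn_step(x_t, h, w_in, w_rec):
    """One SimpleRNN-style step on scalars: h_new = tanh(w_in * x_t + w_rec * h)."""
    return math.tanh(w_in * x_t + w_rec * h)


def run_looped(xs, w_in, w_rec):
    """Loop form: a single step body plus an integer counter, like a WHILE op."""
    t, h = 0, 0.0
    while t < len(xs):  # plays the role of LESS in the condition subgraph
        h = rnn_step(xs[t], h, w_in, w_rec)
        t = t + 1       # integer counter increment (an INT32 ADD in the model)
    return h


def unrolled_op_list(n_steps):
    """Unrolled form: one entry per timestep, each with a fixed time index."""
    return [("rnn_step", t) for t in range(n_steps)]


def run_unrolled(xs, w_in, w_rec):
    """Execute the static op list; no counter arithmetic is part of the graph."""
    h = 0.0
    for _, t in unrolled_op_list(len(xs)):
        h = rnn_step(xs[t], h, w_in, w_rec)
    return h


xs = [0.1 * i for i in range(8)]
print(run_looped(xs, 0.5, 0.9) == run_unrolled(xs, 0.5, 0.9))
print(len(unrolled_op_list(128)))  # number of step copies for 128 timesteps
```

Both forms compute the same result, but the unrolled op list grows linearly with the number of timesteps, which is consistent with the size increase reported in this thread.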

