
[onert-micro] how to support on-device training on model with GRU ? #13365

Open chunseoklee opened 1 month ago

chunseoklee commented 1 month ago

The GRU operation in circle can be defined in two ways. During conversion from Keras, it may be converted into:

  • multiple subgraphs (the decomposed form, where the recurrence is lowered to a WHILE loop whose condition and body are separate subgraphs), or
  • a single "Custom" GRU operation (as in onert-micro).

IMHO, onert-micro is not ready to handle training on multiple subgraphs.

chunseoklee commented 1 month ago

@BalyshevArtem PTAL

BalyshevArtem commented 1 month ago
  • Single "Custom" GRU operation as in (onert-micro)

I think it is better to use the custom GRU. It will also give better latency and memory consumption. And in my opinion it is easier to support (maybe I am wrong).
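
For context, here is a minimal NumPy sketch of what a single fused GRU kernel computes (gate order z/r/h as in Keras; reset_after=False for brevity, whereas Keras defaults to reset_after=True). The whole time loop runs inside one kernel, which is why the fused form can avoid the per-step dispatch of the multi-subgraph decomposition:

  import numpy as np

  def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

  def gru_step(x, h, W, U, b):
      # One GRU time step. W: (features, 3*units), U: (units, 3*units),
      # b: (3*units,); gates are concatenated in Keras order [z, r, h].
      Wz, Wr, Wh = np.split(W, 3, axis=-1)
      Uz, Ur, Uh = np.split(U, 3, axis=-1)
      bz, br, bh = np.split(b, 3)
      z = sigmoid(x @ Wz + h @ Uz + bz)            # update gate
      r = sigmoid(x @ Wr + h @ Ur + br)            # reset gate
      h_hat = np.tanh(x @ Wh + (r * h) @ Uh + bh)  # candidate state
      return z * h + (1.0 - z) * h_hat             # new hidden state

  def fused_gru(x_seq, W, U, b, units):
      # The entire recurrence runs inside one "kernel": a loop over time steps.
      h = np.zeros((x_seq.shape[0], units), dtype=x_seq.dtype)
      for t in range(x_seq.shape[1]):
          h = gru_step(x_seq[:, t, :], h, W, U, b)
      return h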

chunseoklee commented 4 weeks ago

Here are a reference GRU model and a fused GRU model produced by #13602:

gru_fused.zip

The tflite model is generated by the following code:

  import numpy as np
  import tensorflow as tf

  # Calibration data for the input normalization layer.
  adapt_data = np.array([[0., 7., 4., 0.5],
                         [2., 9., 6., -0.5],
                         [0., 7., 4., -0.5],
                         [2., 9., 6., 0.5]], dtype='float32')
  normalization_layer = tf.keras.layers.Normalization(axis=-1)
  normalization_layer.adapt(adapt_data)

  classes = 4
  activation = 'tanh'
  model = tf.keras.models.Sequential([
      tf.keras.Input(shape=(10, 4)),
      normalization_layer,
      tf.keras.layers.GRU(units=20, activation=activation, use_bias=True, bias_initializer="ones"),
      tf.keras.layers.Dense(classes, activation='softmax')
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
  model.summary()

  run_model = tf.function(lambda x: model(x))

  # This is important, let's fix the input size.
  BATCH_SIZE = 1
  X = 10
  Y = 4
  concrete_func = run_model.get_concrete_function(
      tf.TensorSpec([BATCH_SIZE, X, Y], model.inputs[0].dtype))

  # Save the model together with the fixed-shape signature.
  MODEL_DIR = "keras_model"
  model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

  # Convert the SavedModel to tflite using builtin ops only.
  converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
  converter.experimental_new_converter = True
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  converted_model = converter.convert()

  save_to = "GRU.tflite"
  with open(save_to, 'wb') as tf_lite_file:
      tf_lite_file.write(converted_model)

and then apply #13625.
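
To double-check which form the converter produced, the op/subgraph layout of the resulting .tflite file can be dumped (this assumes a TensorFlow version that ships tf.lite.experimental.Analyzer, i.e. 2.9 or later); the decomposed model shows up as a WHILE op with separate condition/body subgraphs:

  import tensorflow as tf

  # Print the op and subgraph structure of the converted model; a decomposed
  # GRU appears as a WHILE op plus extra subgraphs for its condition and body.
  tf.lite.experimental.Analyzer.analyze(model_path="GRU.tflite")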

chunseoklee commented 3 weeks ago

Let's try to train the GRU operation with the model at https://github.com/Samsung/ONE/issues/13365#issuecomment-2272718661

BalyshevArtem commented 2 weeks ago

Training results

Here are the training results for #13737.

The model was obtained from the following code:

  import tensorflow as tf

  classes = 4
  activation = 'tanh'
  model = tf.keras.models.Sequential([
      tf.keras.Input(shape=(60, 3)),
      tf.keras.layers.GRU(units=60, activation=activation),
      tf.keras.layers.Dense(classes, activation='softmax')
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
  model.summary()
  run_model = tf.function(lambda x: model(x))

  # This is important, let's fix the input size.
  BATCH_SIZE = 1
  X = 60
  Y = 3
  concrete_func = run_model.get_concrete_function(
      tf.TensorSpec([BATCH_SIZE, X, Y], model.inputs[0].dtype))

  # Save the model together with the fixed-shape signature.
  MODEL_DIR = "keras_model"
  model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

  # Convert the SavedModel to tflite using builtin ops only.
  converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
  converter.experimental_new_converter = True
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  converted_model = converter.convert()

  save_to = "gru_stick.tflite"
  with open(save_to, 'wb') as tf_lite_file:
      tf_lite_file.write(converted_model)

The training data is the data for the targeted model: in this experiment, 1000 random samples from the original training data were used for training and 150 for testing. The task is classification; I used cross entropy as the loss and accuracy as the metric. To verify that the GRU layer actually learns, we first train only the last FullyConnected layer of the initial model, and then train both the FullyConnected layer and the GRU layer (a plain-Keras sketch of this two-phase schedule appears at the end of this comment).

Initial values:

Test Average ACCURACY = 0.34
Test Average CROSS ENTROPY = 2.54871

Train only last (FullyConnected) layer:

Test Average ACCURACY = 0.61
Test Average CROSS ENTROPY = 0.898501

Train last FullyConnected + GRU:

Test Average ACCURACY = 0.72
Test Average CROSS ENTROPY = 0.751191

Thus, the GRU layer is indeed being trained and helps to achieve better results on this task.
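
For reference, the two-phase schedule described above can be sketched in plain Keras; this is not the onert-micro training code, and the random train_x/train_y below are only placeholders for the real 1000-sample training set:

  import numpy as np
  import tensorflow as tf

  classes = 4
  model = tf.keras.models.Sequential([
      tf.keras.Input(shape=(60, 3)),
      tf.keras.layers.GRU(units=60, activation='tanh'),
      tf.keras.layers.Dense(classes, activation='softmax')
  ])

  # Placeholder data with the experiment's shapes (the real experiment used
  # 1000 training and 150 test samples drawn from the target model's data).
  train_x = np.random.rand(1000, 60, 3).astype('float32')
  train_y = tf.keras.utils.to_categorical(
      np.random.randint(0, classes, size=1000), classes)

  # Phase 1: freeze the GRU so only the last FullyConnected (Dense) layer learns.
  model.layers[0].trainable = False
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss='categorical_crossentropy', metrics=['accuracy'])
  model.fit(train_x, train_y, epochs=5)

  # Phase 2: unfreeze the GRU and train both layers (recompile after
  # changing trainable flags so the change takes effect).
  model.layers[0].trainable = True
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss='categorical_crossentropy', metrics=['accuracy'])
  model.fit(train_x, train_y, epochs=5)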