tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Unexpected Inference Time and Model Size for TensorFlow Lite and Pruned Models #1135

Open experimentsym3 opened 2 weeks ago

experimentsym3 commented 2 weeks ago

I converted my TensorFlow models (.h5 format) to TensorFlow Lite, including quantized and pruned versions. Note: my model is SqueezeNet.

Script to Save Lite Model:

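import pathlib
import tensorflow as tf

def create_tflite_model(original_model):
    # Convert a Keras model to a float TFLite flatbuffer (no quantization applied here).
    converter = tf.lite.TFLiteConverter.from_keras_model(original_model)
    # Allow fallback to TensorFlow (Flex) ops for anything without a builtin TFLite kernel.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    lite_model = converter.convert()
    return lite_model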

def save_tflite_model(model, method):
    path_ = "../Lite"
    tflite_models_dir = pathlib.Path(path_)
    tflite_models_dir.mkdir(exist_ok=True, parents=True)
    namee = f"{method}.tflite"
    tflite_model_file = tflite_models_dir / namee
    tflite_model_file.write_bytes(model)

lite_model = create_tflite_model(original_model)
save_tflite_model(lite_model, "lite")

Script to Create Pruned Version:

def create_pruned_model(original_model, data):
    x_train, x_val, x_test, y_train, y_val, y_test = data
    batch_size = 64
    epochs = 200

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    num_samples = x_train.shape[0]
    end_step = np.ceil(num_samples / batch_size).astype(np.int32) * epochs
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.30,
                                                                 final_sparsity=0.60,
                                                                 begin_step=0,
                                                                 end_step=end_step,
                                                                 frequency=100)
    }

    model_for_pruning = prune_low_magnitude(original_model, **pruning_params)
    model_for_pruning.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001, beta_1=0.9),
                              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                              metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])

    callbacks = [
        keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True, verbose=0),
        tfmot.sparsity.keras.UpdatePruningStep(),
        SparsityLoggingCallback(log_frequency=1)  # user-defined callback (definition not shown)
    ]

    model_for_pruning.fit(x_train, y_train,
                          batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val),
                          callbacks=callbacks)

    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    pruned_model = create_tflite_model(model_for_export)
    return pruned_model

pruned_model = create_pruned_model(original_model, data)
save_tflite_model(pruned_model, "pruned")
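
To see what sparsity the schedule above targets at a given training step, here is a minimal pure-Python sketch of a polynomial decay schedule. It assumes the formula and default exponent (power=3) documented for tfmot.sparsity.keras.PolynomialDecay, and it is an illustration rather than the library's code. The dataset size of 6400 samples is a hypothetical value chosen so that one epoch is 100 steps. Because EarlyStopping with patience=10 can end training well before end_step, the exported model may have less than the 0.60 final sparsity.

```python
import math


def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` under a polynomial decay schedule.

    Assumption: mirrors the formula documented for tfmot's PolynomialDecay
    (default power=3); this is an illustration, not the library's code.
    """
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


# Hypothetical dataset size; batch_size and epochs match the script above.
num_samples, batch_size, epochs = 6400, 64, 200
steps_per_epoch = math.ceil(num_samples / batch_size)
end_step = steps_per_epoch * epochs

for epoch in (0, 20, 100, 200):
    step = epoch * steps_per_epoch
    sparsity = polynomial_decay_sparsity(step, 0.30, 0.60, 0, end_step)
    print(f"epoch {epoch:3d} (step {step:5d}): target sparsity {sparsity:.4f}")
```

Under these assumptions, stopping after 20 epochs would leave the target sparsity much closer to the initial 0.30 than to the final 0.60, so it can be worth logging the epoch at which training actually stopped.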

Script to Create Quantized Version:

def create_quant_model(original_model):
    converter = tf.lite.TFLiteConverter.from_keras_model(original_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    dynamic_range_quant_model = converter.convert()
    return dynamic_range_quant_model

quant_model = create_quant_model(original_model)
save_tflite_model(quant_model, "quantized")

I compared the original, Lite, quantized, and pruned models in terms of validation accuracy, inference time, and model size. I measured model size after gzip compression for all models.
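
As a rough illustration of why the size comparison is done on compressed files, the sketch below uses only the standard library to compare raw and gzip-compressed sizes of a dense weight buffer and a magnitude-pruned copy. The weights are random stand-ins, not taken from any real model, and the helper names are hypothetical. Zeroing weights does not change the raw byte count of a dense float32 buffer; the saving only appears once the runs of zeros are compressed.

```python
import gzip
import random
import struct


def pack_float32(weights):
    """Serialize a list of floats as little-endian float32 bytes."""
    return struct.pack(f"<{len(weights)}f", *weights)


def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    k = int(len(weights) * sparsity)
    if k == 0:
        return list(weights)
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= threshold else w for w in weights]


rng = random.Random(0)
dense = [rng.gauss(0.0, 1.0) for _ in range(20000)]  # stand-in weights
pruned = prune_by_magnitude(dense, 0.60)

dense_bytes = pack_float32(dense)
pruned_bytes = pack_float32(pruned)

print("raw size   dense / pruned:", len(dense_bytes), len(pruned_bytes))
print("gzip size  dense / pruned:",
      len(gzip.compress(dense_bytes)), len(gzip.compress(pruned_bytes)))
```

The get_gzipped_model_size helper later in this post uses zipfile with ZIP_DEFLATED, which applies the same DEFLATE algorithm as gzip, so the same effect is expected there.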

My Questions:

  1. Inference Time:

    • I observe that: Original model > Lite model = Pruned model > Quantized model
    • Question 1: Why does the pruned model have the same inference time as the Lite model? I expected the pruned model to be faster than the Lite model.
  2. Model Size:

    • I observe that: Original model = Lite model > Pruned model > Quantized model
    • Question 2: Why does the Lite model have the same size as the original model? I expected it to be smaller than the original.
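
Orderings like the ones above depend on how the timings were taken, so here is a minimal timing helper for reference. It is a sketch using only the standard library: time_callable is a hypothetical name, and the warm-up and run counts are arbitrary defaults. It takes any zero-argument callable, for example a TFLite interpreter's invoke method after the input tensor has been set.

```python
import statistics
import time


def time_callable(fn, warmup=5, runs=50):
    """Time a zero-argument callable and return summary statistics in seconds.

    Warm-up calls are executed first and excluded from the statistics, so
    one-time setup cost does not skew the result.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        "median_s": statistics.median(samples),
        "min_s": min(samples),
        "max_s": max(samples),
        "runs": runs,
    }


# Stand-in workload; with TFLite this could be `interpreter.invoke`
# after `interpreter.set_tensor(...)` has been called.
def dummy_workload():
    return sum(i * i for i in range(10000))


stats = time_callable(dummy_workload)
print(f"median: {stats['median_s'] * 1e3:.3f} ms over {stats['runs']} runs")
```

Reporting the median over many runs rather than a single measurement makes it easier to tell whether two models that look equal are within measurement noise of each other.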

Update:

I installed the TensorFlow Lite Flex delegate to support operations that TensorFlow Lite does not support natively. Below is a minimal example that currently fails; it is meant to illustrate my second question about model size: why does the Lite model have the same size as the original model? I expected it to be smaller.

I am unsure how to resolve this issue. Interestingly, when I perform the same steps with a simple CNN model, it works perfectly (Lite is smaller than Original and the code works without error). However, it fails with this SqueezeNet example.
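
For a sanity check on what sizes to expect, the sketch below counts the trainable parameters of the SqueezeNetSimple model defined in the script that follows, for the input shape (30, 45, 1) and 5 classes used there. It assumes the standard Conv2D parameter count (kernel height × kernel width × input channels × output channels, plus one bias per output channel); the helper names are hypothetical, and the byte figures cover weights only, not file-format overhead.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Weights plus biases of a standard 2D convolution layer."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


def fire_module_params(in_channels, squeeze_filters):
    """Parameters and output channels of one fire module (squeeze + two expands)."""
    expand_filters = 4 * squeeze_filters
    squeeze = conv2d_params(1, 1, in_channels, squeeze_filters)
    expand_1x1 = conv2d_params(1, 1, squeeze_filters, expand_filters)
    expand_3x3 = conv2d_params(3, 3, squeeze_filters, expand_filters)
    # The two expand outputs are concatenated along the channel axis.
    return squeeze + expand_1x1 + expand_3x3, 2 * expand_filters


def squeezenet_simple_params(input_channels=1, num_classes=5):
    """Total trainable parameters; pooling, Add, and Dropout layers add none."""
    total = conv2d_params(3, 3, input_channels, 16)  # conv1
    channels = 16
    for _ in range(2):  # fire2 and fire3, each with 8 squeeze filters
        params, channels = fire_module_params(channels, 8)
        total += params
    total += conv2d_params(1, 1, channels, num_classes)  # conv2
    return total


total_params = squeezenet_simple_params()
print("parameters:", total_params)
print("float32 weight bytes:", total_params * 4)
print("idealized 8-bit weight bytes:", total_params)
```

With so few parameters, the weights account for only a few tens of kilobytes, so graph metadata and container overhead may make up a noticeable share of each file.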

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Concatenate, Add, GlobalAveragePooling2D, Dropout
from tensorflow.keras import backend as K
import os
import tempfile
import zipfile
import pathlib
import numpy as np
from tensorflow.keras.utils import register_keras_serializable

# Function to print library versions
def print_library_versions():
    print(f"TensorFlow Version: {tf.__version__}")

# Print library versions
print_library_versions()

# Model definition functions
def SqueezeNetSimple(input_shape, num_classes, use_bypass=False, dropout_rate=None):
    input_img = Input(shape=input_shape)
    x = Conv2D(16, (3, 3), activation='relu', padding='same', name='conv1')(input_img)
    x = MaxPooling2D(pool_size=(2, 2), name='maxpool1')(x)
    x = create_fire_module(x, 8, name='fire2')
    x = create_fire_module(x, 8, name='fire3', use_bypass=use_bypass)
    if dropout_rate: x = Dropout(dropout_rate)(x)
    x = Conv2D(num_classes, (1, 1), activation='relu', padding='same', name='conv2')(x)
    x = GlobalAveragePooling2D(name='avgpool2')(x)
    return Model(inputs=input_img, outputs=x)

def create_fire_module(x, nb_squeeze_filter, name, use_bypass=False):
    nb_expand_filter = 4 * nb_squeeze_filter
    squeeze = Conv2D(nb_squeeze_filter, (1, 1), activation='relu', padding='same', name='%s_squeeze' % name)(x)
    expand_1x1 = Conv2D(nb_expand_filter, (1, 1), activation='relu', padding='same', name='%s_expand_1x1' % name)(squeeze)
    expand_3x3 = Conv2D(nb_expand_filter, (3, 3), activation='relu', padding='same', name='%s_expand_3x3' % name)(squeeze)

    axis = -1 if K.image_data_format() == 'channels_last' else 1
    x_ret = Concatenate(axis=axis, name='%s_concatenate' % name)([expand_1x1, expand_3x3])

    if use_bypass:
        x_ret = Add(name='%s_concatenate_bypass' % name)([x_ret, x])

    return x_ret

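# Function to get the compressed size of a model file (ZIP with DEFLATE, the same algorithm gzip uses)
def get_gzipped_model_size(model_name):
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)  # mkstemp returns an open file descriptor; close it so it is not leaked
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(model_name)
    return os.path.getsize(zipped_file)  # Size in bytes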

# Function to save TFLite model
def save_tflite_model(model, a):
    path_ = "../source"
    tflite_models_dir = pathlib.Path(path_)
    tflite_models_dir.mkdir(exist_ok=True, parents=True)
    namee = f"squeezenet_opportunity_{a}_sil.tflite"
    tflite_model_file = tflite_models_dir / namee
    tflite_model_file.write_bytes(model)
    print("Model saved: ", namee)
    return tflite_model_file

# Function to create and save TFLite model
def get_saved_model1(model_dir, a):
    converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    # Disable resource variables
    converter.experimental_enable_resource_variables = False
    # Use the MLIR-based converter and the new quantizer (these flags do not enable logging)
    converter.experimental_new_converter = True
    converter.experimental_new_quantizer = True
    lite_model = converter.convert()
    model_path = save_tflite_model(lite_model, a)
    return model_path

# Example data (replace with your actual data)
x_train = np.random.rand(100, 30, 45, 1)
y_train = np.random.randint(0, 5, 100)  # Dummy labels for 5 classes
num_classes = 5

# Create the simplified model
model = SqueezeNetSimple(x_train.shape[1:], num_classes, use_bypass=True, dropout_rate=0.5)

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model with dummy data
model.fit(x_train, y_train, epochs=1, batch_size=16)  # Adjust epochs and batch_size as needed

# Save the model in TensorFlow SavedModel format
saved_model_path = '../models/Original/squeezenet/saved_squeezenet_simple'
tf.saved_model.save(model, saved_model_path)

# Convert the SavedModel to TFLite
model_path = get_saved_model1(saved_model_path, "m1_simple")
print("Model name: ", model_path)
print("Model size: ", os.path.getsize(model_path))
print("Model size gzip: ", get_gzipped_model_size(model_path))

print("The end")

# Load TFLite model and allocate tensors with Flex delegate
def load_tflite_model_with_flex(model_path, flex_delegate_path):
    try:
        interpreter = tf.lite.Interpreter(model_path=model_path,
                                          experimental_delegates=[tf.lite.experimental.load_delegate(flex_delegate_path)])
        interpreter.allocate_tensors()
        return interpreter
    except Exception as e:
        print(f"Failed to load delegate: {e}")
        return None

# Specify the correct path to the Flex delegate library
flex_delegate_path = 'bazel-bin/tensorflow/lite/delegates/flex/libtensorflowlite_flex_delegate.so'

if not os.path.exists(flex_delegate_path):
    raise FileNotFoundError(f"Flex delegate library not found at {flex_delegate_path}")

interpreter1 = load_tflite_model_with_flex(model_path, flex_delegate_path)

if interpreter1:
    # Test the model on random input data
    input_details = interpreter1.get_input_details()
    output_details = interpreter1.get_output_details()

    input_shape = input_details[0]['shape']
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter1.set_tensor(input_details[0]['index'], input_data)
    interpreter1.invoke()

    output_data = interpreter1.get_tensor(output_details[0]['index'])
    print("Output from model 1: ", output_data)

print("The end")

This is the output I got (I am using a MacBook Pro with an M3 Pro chip):

(tf_env) username@User-MBP source % python myscript.py
TensorFlow Version: 2.16.1
7/7 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - accuracy: 0.0782 - loss: 3.1644
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1719010068.685977 2618997 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1719010068.686002 2618997 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
2024-06-22 01:47:48.686209: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: ../models/Original/squeezenet/saved_squeezenet_simple
2024-06-22 01:47:48.687045: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-06-22 01:47:48.687053: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: ../models/Original/squeezenet/saved_squeezenet_simple
2024-06-22 01:47:48.696450: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2024-06-22 01:47:48.697860: I tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
2024-06-22 01:47:48.735235: I tensorflow/cc/saved_model/loader.cc:218] Running initialization op on SavedModel bundle at path: ../models/Original/squeezenet/saved_squeezenet_simple
2024-06-22 01:47:48.746486: I tensorflow/cc/saved_model/loader.cc:317] SavedModel load for tags { serve }; Status: success: OK. Took 60277 microseconds.
2024-06-22 01:47:48.756733: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
loc(fused["ReadVariableOp:", "functional_1_1/conv1_1/Reshape/ReadVariableOp@__inference_serving_default_3237"]): error: missing attribute 'value'
LLVM ERROR: Failed to infer result type(s).
zsh: abort python myscript.py