axinc-ai / ailia-models-tflite

Quantized version of model library
23 stars · 1 fork · source link

ADD ESPCN #16

Closed kyakuno closed 2 years ago

kyakuno commented 2 years ago

https://github.com/keras-team/keras-io/blob/master/examples/vision/super_resolution_sub_pixel.py

kyakuno commented 2 years ago

Adding the following to the end of super_resolution_sub_pixel.py makes it possible to convert the model to tflite and quantize it.

# Persist the trained Keras model as a SavedModel, then export a
# plain float32 TFLite flatbuffer from it (no quantization yet).
model.save("saved_model_espcn")

# Convert the SavedModel to a float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_espcn")
float_tflite_model = float_converter.convert()
with open("espcn.tflite", 'wb') as out_file:
    out_file.write(float_tflite_model)

# Prepare calibration data for post-training quantization.
# Each sample is the normalized Y (luma) channel of a LOW-RESOLUTION image,
# shaped (1, H, W, 1) — this must match what the ESPCN model receives at
# inference time.
validation_data_set = []
for test_img_path in test_img_paths[50:60]:
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    # BUG FIX: calibrate on the downscaled image (the model's actual input).
    # The original code computed `lowres_input` but then converted the
    # full-resolution `img`, so the representative dataset did not match
    # the model's input distribution.
    ycbcr = lowres_input.convert("YCbCr")
    y, cb, cr = ycbcr.split()
    y = img_to_array(y)
    y = y.astype("float32") / 255.0
    # Add a leading batch dimension; avoid shadowing the builtin `input`.
    sample = np.expand_dims(y, axis=0)
    validation_data_set.append(sample)

# Generator feeding calibration samples to the TFLite quantizer.
def representative_dataset_gen():
    """Yield each calibration sample wrapped in a single-element list,
    the format tf.lite.TFLiteConverter expects for representative data."""
    for sample in validation_data_set:
        yield [sample]

# Build a fully integer (int8) quantized TFLite model, calibrating the
# activation ranges with the representative dataset defined above.
quant_converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_espcn")
quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]
quant_converter.representative_dataset = representative_dataset_gen
# Restrict to int8 builtin ops and make the model's I/O tensors int8 too.
quant_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
quant_converter.inference_input_type = tf.int8
quant_converter.inference_output_type = tf.int8
tflite_quant_model = quant_converter.convert()

with open("espcn_quant.tflite", 'wb') as out_file:
    out_file.write(tflite_quant_model)