google-ai-edge / mediapipe

Cross-platform, customizable ML solutions for live and streaming media.
https://mediapipe.dev
Apache License 2.0

Save model training progress for converting to int8 later. #5522

Closed hiteshtechshslok closed 1 month ago

hiteshtechshslok commented 1 month ago

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

None

OS Platform and Distribution

Kaggle

Python Version

3.10.13

MediaPipe Model Maker version

Version: 0.2.1.4

Task name (e.g. Image classification, Gesture recognition etc.)

Object detection

Describe the actual behavior

I am training a custom model on a Roboflow dataset and it is training well. I want to save the model in some format, such as H5, from which I can keep the weights and later retrain the model with QAT for int8 conversion.

Describe the expected behaviour

I am unable to find any example that saves training progress; the provided examples just fail with a "not found" error or something similar.

Standalone code/steps you may have used to try to get what you need

pip install mediapipe-model-maker

import os
import json
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import object_detector
train_dataset_path = "/kaggle/working/merged_db/train"
validation_dataset_path = "/kaggle/working/merged_db/valid"

with open(os.path.join(train_dataset_path, "labels.json"), "r") as f:
  labels_json = json.load(f)
for category_item in labels_json["categories"]:
  print(f"{category_item['id']}: {category_item['name']}")

train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir="/train")
validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir="/validation")
print("train_data size: ", train_data.size)
print("validation_data size: ", validation_data.size)

spec = object_detector.SupportedModels.MOBILENET_V2
hparams = object_detector.HParams(export_dir='exported_model', epochs=100)
options = object_detector.ObjectDetectorOptions(
    supported_model=spec,
    hparams=hparams
)

model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options)

loss, coco_metrics = model.evaluate(validation_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"Validation coco metrics: {coco_metrics}")

model.export_model()

qat_hparams = object_detector.QATHParams(learning_rate=0.009, batch_size=16, epochs=30, decay_steps=10, decay_rate=0.96)
model.restore_float_ckpt()

model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
qat_loss, qat_coco_metrics = model.evaluate(validation_data)
print(f"QAT validation loss: {qat_loss}")
print(f"QAT validation coco metrics: {qat_coco_metrics}")

model.export_model('model_int8_qat.tflite')

Other info / Complete Logs

model.save_checkpoint("/kaggle/working/exp/")

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[20], line 2
      1 # model.model.export_saved_model()
----> 2 model.save_checkpoint("/kaggle/working/exp/")

AttributeError: 'ObjectDetector' object has no attribute 'save_checkpoint'

I am unable to save the model in any format and restore from it so that I can perform the following step without having to train the model to completion again:

qat_hparams = object_detector.QATHParams(learning_rate=0.009, batch_size=16, epochs=30, decay_steps=10, decay_rate=0.96)
model.restore_float_ckpt()

model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
qat_loss, qat_coco_metrics = model.evaluate(validation_data)
print(f"QAT validation loss: {qat_loss}")
print(f"QAT validation coco metrics: {qat_coco_metrics}")
HTerminal commented 1 month ago

Did anyone find a solution to this problem?

kuaashish commented 1 month ago

Hi @joezoug,

Do you have any suggestions for this issue? It seems they are looking to save the model in H5 or a similar format to retain the weights for retraining in the QAT process for INT8 conversion, but are encountering a "not found" error.

Thank you!!

HTerminal commented 1 month ago

Hello @kuaashish

Any update on how this can be done?

Thank you

joezoug commented 1 month ago

Hi @HTerminal,

If you call ObjectDetector.create(), it actually saves a floating point model checkpoint at the end of training with a call to the _save_float_ckpt API.

You can then restore the floating point model after you run QAT by using the restore_float_ckpt API.

Here is a guide on this workstream: https://ai.google.dev/edge/mediapipe/solutions/customization/object_detector#quantization_aware_training_int8_quantization.

Hope that helps!

HTerminal commented 1 month ago

Hello @joezoug, thanks for the response. But if you look here, it says that you must run the create method and perform the training first, in the same session; only then can you call restore any number of times. I just want to run the training once, come back later, and reuse the trained model.

Check here - https://ai.google.dev/edge/mediapipe/solutions/customization/object_detector#quantization_aware_training_int8_quantization:~:text=After%20running%20the%20create%20method

joezoug commented 1 month ago

@HTerminal,

I see, thanks for clarifying. Unfortunately we don't have a direct API to do this, but you can write some custom code to get the behavior that you want. Our code is open source, so you can refer to the create method and make the same function calls, except omit the _train_model and _save_float_ckpt calls. Once you have initialized an ObjectDetector instance this way, without the training step, you can call restore_float_ckpt to load the model.

Note that the method I detailed above requires you to re-use the same hparams.export_dir as the first training run where you saved the float checkpoint. If you want to avoid this restriction, you can customize the code in the restore_float_ckpt method:

self._model.load_checkpoint(
    <INSERT CUSTOM PATH>,
    include_last_layer=True,
)
self._model.compile()
self._is_qat = False
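
Putting the suggestion above together, a minimal sketch of the session-restart workflow might look like the following. This is only an illustration based on the APIs named in this thread (restore_float_ckpt, _train_model, _save_float_ckpt are private and version-dependent); the exact ObjectDetector constructor arguments are elided and should be copied from the create() source of the installed mediapipe-model-maker version.

```python
from mediapipe_model_maker import object_detector

# Use the SAME export_dir as the original training run:
# restore_float_ckpt() reads the float checkpoint saved there by create().
hparams = object_detector.HParams(export_dir='exported_model', epochs=100)
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=hparams,
)

# Hypothetical re-creation step: replicate the constructor and model-building
# calls found inside ObjectDetector.create(), but omit _train_model() and
# _save_float_ckpt(). Fill in the constructor arguments from your version's
# create() source; they are intentionally elided here.
detector = object_detector.ObjectDetector(...)

# Restore the float checkpoint from the earlier run, then proceed with QAT
# exactly as in the documented workflow.
detector.restore_float_ckpt()
qat_hparams = object_detector.QATHParams(
    learning_rate=0.009, batch_size=16, epochs=30,
    decay_steps=10, decay_rate=0.96)
detector.quantization_aware_training(
    train_data, validation_data, qat_hparams=qat_hparams)
detector.export_model('model_int8_qat.tflite')
```

Since the private method names may change between releases, it is safest to pin the mediapipe-model-maker version used for both the original training run and the later restore.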
github-actions[bot] commented 1 month ago

This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 1 month ago

This issue was closed due to lack of activity after being marked stale for past 7 days.

google-ml-butler[bot] commented 1 month ago

Are you satisfied with the resolution of your issue?