tensorflow / models

Models and examples built with TensorFlow

MobileDet CPU doesn't support QAT training #8734

Open Apollo-XI opened 4 years ago

Apollo-XI commented 4 years ago

Prerequisites

Please answer the following questions for yourself before submitting an issue.

1. The entire URL of the file you are using

https://github.com/tensorflow/models/tree/master/research/object_detection

2. Describe the bug

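MobileDet CPU doesn't support quantization-aware training (QAT): training with a graph_rewriter quantization block fails when the model is restored from a checkpoint.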

3. Steps to reproduce

Add QAT to the COCO training config (ssdlite_mobiledet_cpu_320x320_coco_sync_4x4.config):

graph_rewriter {
  quantization {
    delay: 50
    activation_bits: 8
    weight_bits: 8
  }
}

Training then fails after the first checkpoint is saved once quantization has started, specifically when it tries to restore the model from that checkpoint:

tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key FeatureExtractor/MobileDetCPU/Conv/conv_quant/max not found in checkpoint
     [[node save/RestoreV2 (defined at /tensorflow-1.15.2/python3.6/tensorflow_core/python/framework/ops.py:1748) ]]
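For context, the missing key FeatureExtractor/MobileDetCPU/Conv/conv_quant/max appears to be one of the min/max range variables that the quantization rewrite (tf.contrib.quantize in TF 1.15) inserts next to each quantized tensor. With activation_bits: 8 those ranges define a 256-level grid, and delay: 50 postpones quantization for the first 50 steps. The stdlib-only sketch below is a simplified illustration of that fake-quantization arithmetic under those assumptions; it is not the tf.contrib.quantize implementation (TF's fake-quant op also nudges the range so zero is exactly representable, and activation ranges are tracked as moving averages), and the function names fake_quant and maybe_fake_quant are hypothetical.

```python
def fake_quant(x: float, range_min: float, range_max: float, num_bits: int = 8) -> float:
    """Snap x onto a uniform num_bits grid over [range_min, range_max] and map it back to float.

    Simplified: no zero-point nudging, unlike TF's fake-quant op.
    """
    levels = 2 ** num_bits - 1                  # 255 steps for 8 bits
    scale = (range_max - range_min) / levels    # float width of one quantization step
    clamped = min(max(x, range_min), range_max)
    q = round((clamped - range_min) / scale)    # integer code in [0, levels]
    return q * scale + range_min


def maybe_fake_quant(x: float, step: int, delay: int,
                     range_min: float, range_max: float, num_bits: int = 8) -> float:
    """Return x unchanged until `step` reaches `delay` (cf. `delay: 50` in the config)."""
    if step < delay:
        return x
    return fake_quant(x, range_min, range_max, num_bits)


if __name__ == "__main__":
    for step in (0, 49, 50, 100):
        print(step, maybe_fake_quant(0.1234, step, delay=50, range_min=-1.0, range_max=1.0))
```

Before step 50 the value passes through untouched; from step 50 on it is snapped to the nearest of 256 levels between the stored min and max, which is why those range variables must be saved in, and restored from, every checkpoint.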

4. Expected behavior

Training with QAT enabled should run to completion and restore from its own checkpoints without errors.

5. Additional context

The resulting checkpoint also fails to convert to an inference graph with either python object_detection/export_inference_graph.py or python object_detection/export_tflite_ssd_graph.py, with the same error:

tensorflow.python.framework.errors_impl.NotFoundError: Key FeatureExtractor/MobileDetCPU/Conv/conv_quant/max not found in checkpoint
     [[node save/RestoreV2 (defined at /Users/vferrer/miniconda3/envs/Pytorch/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]

Disabling intermediate checkpoints and running the full training does not help either: it fails with the same error at the end of training, when it tries to restore the model from the checkpoint, and again in python object_detection/export_inference_graph.py or python object_detection/export_tflite_ssd_graph.py.

6. System information

bmount commented 4 years ago

I have the same issue with MnasFPN + MobileNet v2 (I added essentially the same quantization block as the OP to the reference MnasFPN config). The specific missing nodes are:

  (0) Not found: Key FeatureExtractor/MnasFPN/cell_0/add/activation_AddV2_quant/max not found in checkpoint
     [[node save/RestoreV2 (defined at /home/mlft/fused/venv3/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]
  (1) Not found: Key FeatureExtractor/MnasFPN/cell_0/add/activation_AddV2_quant/max not found in checkpoint
     [[node save/RestoreV2 (defined at /home/mlft/fused/venv3/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]
     [[save/RestoreV2/_301]]

I think this may be related: https://github.com/tensorflow/models/issues/8481

I've been looking for anything that might be filtering out the saved parameters. There may also be a missing step that converts some delayed quantization ops into the final inference graph, or a TF1/TF2 issue specific to AddV2.
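One way to narrow down which parameters are being filtered is to diff the variable names the graph expects against the keys actually stored in the checkpoint. The sketch below is stdlib-only and works on plain lists of names; in a TF 1.15 session the expected names could come from [v.op.name for v in tf.global_variables()] and the stored keys from [name for name, _ in tf.train.list_variables(checkpoint_path)]. The helper names missing_keys and missing_quant_keys, and the sample names other than the max key from the log above, are hypothetical.

```python
import re
from typing import Iterable, List

# Fake-quant range variables end in "<something>_quant/min" or "<something>_quant/max".
QUANT_RANGE = re.compile(r"_quant/(min|max)$")


def missing_keys(expected: Iterable[str], stored: Iterable[str]) -> List[str]:
    """Names the graph expects that are absent from the checkpoint, sorted."""
    return sorted(set(expected) - set(stored))


def missing_quant_keys(expected: Iterable[str], stored: Iterable[str]) -> List[str]:
    """Restrict missing_keys() to fake-quantization range variables."""
    return [name for name in missing_keys(expected, stored) if QUANT_RANGE.search(name)]


if __name__ == "__main__":
    expected = [
        "FeatureExtractor/MnasFPN/cell_0/add/activation_AddV2_quant/max",  # from the log
        "FeatureExtractor/MnasFPN/cell_0/add/activation_AddV2_quant/min",  # hypothetical
        "FeatureExtractor/MnasFPN/example_weights",                        # hypothetical
    ]
    stored = ["FeatureExtractor/MnasFPN/example_weights"]
    for name in missing_quant_keys(expected, stored):
        print("missing from checkpoint:", name)
```

If every missing name matches the _quant/min or _quant/max pattern, the range variables are being created in the graph but never written to the checkpoint, which points at the saver's variable list rather than at the model weights.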

Apollo-XI commented 4 years ago

I think this may be related too: #8445. It seems that QAT is currently broken in the Object Detection API for at least the MobileNet variants :(

Apollo-XI commented 4 years ago

After some time, I managed to train it by setting inplace_batchnorm_update: false. As a result, I have to run quantization-aware training on GPU instead of on a TPU pod.
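To apply that workaround to a pipeline config without hand-editing it, a plain text substitution is enough. This is a stdlib-only sketch that assumes the flag is written on its own line as inplace_batchnorm_update: true (for example inside the model { ssd { ... } } block, which is an assumption about the config layout); it does not parse the protobuf, so loading the file with protobuf's text_format would be more robust. The function name disable_inplace_batchnorm_update is hypothetical.

```python
import re

# Matches "inplace_batchnorm_update: true" at the start of a line (after indentation).
_FLAG = re.compile(r"^(\s*inplace_batchnorm_update\s*:\s*)true\b", re.MULTILINE)


def disable_inplace_batchnorm_update(config_text: str) -> str:
    """Rewrite `inplace_batchnorm_update: true` as false; raise if no such line exists."""
    patched, count = _FLAG.subn(r"\1false", config_text)
    if count == 0:
        raise ValueError("no 'inplace_batchnorm_update: true' line found")
    return patched


if __name__ == "__main__":
    # Hypothetical excerpt of a pipeline config.
    sample = "model {\n  ssd {\n    inplace_batchnorm_update: true\n  }\n}\n"
    print(disable_inplace_batchnorm_update(sample))
```

Raising when nothing matches makes it obvious if the flag is spelled or placed differently in a given config, instead of silently writing the file back unchanged.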

This workaround was suggested in https://github.com/tensorflow/models/issues/8331#issuecomment-611279484 for quantization-aware training of MobileNetV3.

holyhao commented 4 years ago

@Apollo-XI I have the same issue. Anything new here? Were you able to train it and convert it to TFLite successfully?

Apollo-XI commented 4 years ago

No. I ended up training on GPU and converting to TFLite, and I had no problems with the conversion. I used this command:

tflite_convert \
    --graph_def_file=tflite_graph.pb \
    --output_file=detect.tflite \
    --input_shapes=1,320,320,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --inference_input_type=QUANTIZED_UINT8 \
    --allow_custom_ops \
    --mean_values=128 \
    --std_dev_values=128 \
    --change_concat_input_ranges=false
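The --mean_values and --std_dev_values flags describe how a quantized uint8 input maps back to the float range the model was trained on: real_value = (quantized_value - mean_value) / std_dev_value. With 128/128, uint8 inputs 0..255 land in roughly [-1, 1), consistent with the normalized_input_image_tensor input name. The stdlib sketch below just evaluates that formula, plus an inverse with rounding and clamping added as an illustrative assumption, so you can check the resulting range; the helper names are hypothetical.

```python
MEAN = 128.0     # --mean_values=128
STD_DEV = 128.0  # --std_dev_values=128


def dequantize(q: int, mean: float = MEAN, std_dev: float = STD_DEV) -> float:
    """real_value = (quantized_value - mean) / std_dev."""
    return (q - mean) / std_dev


def quantize(real: float, mean: float = MEAN, std_dev: float = STD_DEV) -> int:
    """Inverse mapping, rounded and clamped to the uint8 range [0, 255]."""
    return int(min(255, max(0, round(real * std_dev + mean))))


if __name__ == "__main__":
    print(dequantize(0), dequantize(128), dequantize(255))  # -1.0 0.0 0.9921875
    print(quantize(-1.0), quantize(0.0), quantize(1.0))     # 0 128 255
```

Note that 1.0 itself is not representable with these settings: it would need code 256, so it clamps to 255, which decodes to 0.9921875.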