Closed — yillkid closed this issue 3 years ago.
# Fragment from main(): webcam classification loop. `cap`, `labels`,
# `interpreter`, and classify_image() are defined earlier in the full
# script, outside this excerpt.
key_detect = 0          # set to 1 when the user presses 'q' to quit
times=1                 # NOTE(review): never used — candidate for removal
old_labels = ""         # the label currently treated as the stable reference
frame_count = 0         # frames whose label matched old_labels
while (key_detect==0):
    ret,image_src =cap.read()
    # NOTE(review): `ret` is not checked; a failed read makes image_src None
    # and cv2.resize would raise — confirm the camera is assumed reliable.
    image=cv2.resize(image_src,(224,224))
    start_time = time.time()
    results = classify_image(interpreter, image)
    elapsed_ms = (time.time() - start_time) * 1000  # inference latency (computed but unused)
    label_id, prob = results[0]
    if frame_count==0:
        # Adopt the current prediction as the reference label.
        old_labels = labels[label_id]
        #print("Reset old_label")
    if old_labels == labels[label_id]:
        frame_count = frame_count+1
        #print("frame count ++")
    # NOTE(review): there is no else-branch resetting frame_count when the
    # label differs, so any 30 matching frames (not necessarily consecutive)
    # reach the threshold below.
    if frame_count>=30:
        # NOTE(review): the message says "Update old_label" but old_labels is
        # NOT reassigned here; it only changes via the frame_count==0 branch
        # above on the next iteration.
        print("Update old_label " + labels[label_id] + " and reset frame count")
        frame_count = 0
    print("frame_count:"+str(frame_count) +" tensor id:" + labels[label_id] + " and old id: " + str(old_labels))
    # cv2.imshow('Detecting....',image_src)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        key_detect = 1
cap.release()
cv2.destroyAllWindows()

if __name__ == '__main__':
    main()
參考程式碼:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import io
import time
import numpy as np
#import picamera
import cv2
from PIL import Image
from tflite_runtime.interpreter import Interpreter
def load_labels(path):
    """Read a labels file and return a dict mapping line index -> stripped line."""
    with open(path, 'r') as handle:
        lines = handle.readlines()
    return dict(enumerate(line.strip() for line in lines))
def set_input_tensor(interpreter, image):
    """Copy *image* into batch slot 0 of the interpreter's input tensor, in place."""
    input_details = interpreter.get_input_details()
    index = input_details[0]['index']
    tensor = interpreter.tensor(index)()
    tensor[0][:, :] = image
def classify_image(interpreter, image, top_k=1):
    """Run inference on *image* and return the top_k (label_id, score) pairs,
    best score first.

    Fix: np.argpartition only guarantees that the top_k entries occupy the
    first positions — it does NOT order them, so for top_k > 1 the original
    code violated its own "sorted" promise. We now sort the selected indices
    by descending score.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()
    output_details = interpreter.get_output_details()[0]
    output = np.squeeze(interpreter.get_tensor(output_details['index']))
    # If the model is quantized (uint8 data), dequantize the results.
    if output_details['dtype'] == np.uint8:
        scale, zero_point = output_details['quantization']
        output = scale * (output - zero_point)
    top = np.argpartition(-output, top_k)[:top_k]
    ordered = top[np.argsort(-output[top])]
    return [(i, output[i]) for i in ordered]
def main():
    """Classify webcam frames with a TFLite model and maintain a "default"
    label that is replaced only after 10 consecutive frames disagree with it.

    Command line:
        --model   path of the .tflite model file (required)
        --labels  path of the labels text file (required)
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--model', help='File path of .tflite file.', required=True)
    parser.add_argument(
        '--labels', help='File path of labels file.', required=True)
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = Interpreter(args.model)
    interpreter.allocate_tensors()
    _, height, width, _ = interpreter.get_input_details()[0]['shape']

    cap = cv2.VideoCapture(0)
    # Request a 640x480 capture from the camera.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

    key_detect = 0        # becomes 1 when the user presses 'q'
    frame_count = 0       # total frames processed
    diff_frame_count = 0  # consecutive frames whose label differs from default
    # Fix: original initialized a misspelled 'defalut_labels'; the variable
    # actually read below was only ever created inside the loop.
    default_labels = ""

    while key_detect == 0:
        ret, image_src = cap.read()
        if not ret:
            # Fix: a failed read previously crashed cv2.resize on None.
            continue
        # Fix: resize to the model's expected input size (was hard-coded 224x224
        # even though height/width were already read from the model above).
        image = cv2.resize(image_src, (width, height))

        start_time = time.time()
        results = classify_image(interpreter, image)
        elapsed_ms = (time.time() - start_time) * 1000  # inference latency

        label_id, prob = results[0]
        # First frame ever: adopt the prediction as the default label.
        if frame_count == 0:
            default_labels = labels[label_id]
        # Count consecutive disagreements with the default label.
        if labels[label_id] != default_labels:
            print("Hello got different, diff_frame_count : " + str(diff_frame_count))
            diff_frame_count = diff_frame_count + 1
        elif labels[label_id] == default_labels:
            diff_frame_count = 0
        # After 10 consecutive differing frames, replace the default label.
        # (The original comment claimed 30, but the code checks 10.)
        if diff_frame_count == 10:
            print("Hello change default_label from " + default_labels + " to " + labels[label_id])
            default_labels = labels[label_id]
            diff_frame_count = 0
            time.sleep(3)
        frame_count = frame_count + 1
        print("Hello tensorflow got label : " + str(labels[label_id]))
        print("Hello default label : " + default_labels)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            key_detect = 1

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
karina's submitted version:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import io
import time
import numpy as np
#import picamera
import cv2
from PIL import Image
from tflite_runtime.interpreter import Interpreter
def load_labels(path):
    """Map each line number in the file at *path* to that line's stripped text."""
    labels = {}
    with open(path, 'r') as f:
        for idx, raw in enumerate(f.readlines()):
            labels[idx] = raw.strip()
    return labels
def set_input_tensor(interpreter, image):
    """Write *image* into the first batch entry of the model's input tensor."""
    first_input = interpreter.get_input_details()[0]
    destination = interpreter.tensor(first_input['index'])()[0]
    destination[:, :] = image
def classify_image(interpreter, image, top_k=1):
    """Classify *image* and return a list of (label_id, score) tuples ordered
    by descending score.

    Fix: the original relied on np.argpartition alone, which only places the
    top_k entries first without ordering them — so for top_k > 1 the result
    was not actually sorted as the docstring claimed. The selected indices
    are now explicitly sorted by score.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()
    output_details = interpreter.get_output_details()[0]
    output = np.squeeze(interpreter.get_tensor(output_details['index']))
    # If the model is quantized (uint8 data), dequantize the results.
    if output_details['dtype'] == np.uint8:
        scale, zero_point = output_details['quantization']
        output = scale * (output - zero_point)
    candidates = np.argpartition(-output, top_k)[:top_k]
    ranked = candidates[np.argsort(-output[candidates])]
    return [(i, output[i]) for i in ranked]
def main():
    """Classify webcam frames with a TFLite model, tracking an "old" label
    that is re-adopted from the live prediction whenever the streak counter
    is at zero, and reset after 10 consecutive matching frames.

    Command line:
        --model   path of the .tflite model file (required)
        --labels  path of the labels text file (required)
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--model', help='File path of .tflite file.', required=True)
    parser.add_argument(
        '--labels', help='File path of labels file.', required=True)
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = Interpreter(args.model)
    interpreter.allocate_tensors()
    _, height, width, _ = interpreter.get_input_details()[0]['shape']

    cap = cv2.VideoCapture(0)
    # Request a 640x480 capture from the camera.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

    key_detect = 0   # becomes 1 when the user presses 'q'
    old_labels = ""  # the label currently treated as the stable reference
    frame_count = 0  # consecutive frames matching old_labels
    # Fix: removed unused 'times = 1' local.

    while key_detect == 0:
        ret, image_src = cap.read()
        if not ret:
            # Fix: a failed read previously crashed cv2.resize on None.
            continue
        # Fix: resize to the model's expected input size (was hard-coded
        # 224x224 even though height/width were read from the model above).
        image = cv2.resize(image_src, (width, height))

        start_time = time.time()
        results = classify_image(interpreter, image)
        elapsed_ms = (time.time() - start_time) * 1000  # inference latency

        label_id, prob = results[0]
        # Streak at zero: adopt the current prediction as the reference.
        if frame_count == 0:
            old_labels = labels[label_id]
        if old_labels == labels[label_id]:
            frame_count = frame_count + 1
        else:
            # A differing frame breaks the streak.
            frame_count = 0
        if frame_count >= 10:
            # old_labels itself is refreshed on the next iteration via the
            # frame_count == 0 branch above.
            print("Update old_label " + labels[label_id] + " and reset frame count")
            frame_count = 0
        print("frame_count:"+str(frame_count) +" tensor id:" + labels[label_id] + " and old id: " + str(old_labels))
        # cv2.imshow('Detecting....',image_src)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            key_detect = 1

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
已經提交程式碼。
描述
根據原始碼:
預期目標