margaretmm / myblog

my tech blog
0 stars · 1 fork · source link

classify.py #18

Open · margaretmm opened this issue 6 years ago

margaretmm commented 6 years ago

# -*- coding: utf-8 -*-
"""Classify test images with a trained TensorFlow graph (see ``predict``).

Paths are chosen per platform: relative paths on Linux, absolute Windows
paths (plus an OpenCV import, used for on-screen display) everywhere else.
"""

import os
import platform

import numpy as np
import tensorflow as tf

# Serialized graph that predict() loads.
TRAINED_MODEL_FILE = 'trained.pb'
# Presumably the original Google Inception-v3 graph; not used by the active
# code path at present.
TRAINED_MODEL_FILE_ori = 'classify_image_graph_def.pb'

if platform.system() in ['Linux']:
    # Directory holding the downloaded, Google-pretrained Inception-v3 model.
    MODEL_DIR = 'model/'
    TRAINED_MODEL_DIR = 'modelTrained/'
    TEST_IMG = 'data/test/'
else:
    # OpenCV is only needed to display images, which is skipped on Linux.
    import cv2

    # Backslashes must be escaped: in a plain literal '\05' is an octal
    # escape (chr(5)), '\t' is a TAB, and a trailing '\' would swallow the
    # closing quote. Trailing separators are kept so these directories can
    # be joined with a file name by plain string concatenation.
    MODEL_DIR = 'D:\\05_PycharmProjects\\test\\inceptionV3\\model\\'
    TRAINED_MODEL_DIR = 'D:\\05_PycharmProjects\\test\\inceptionV3\\modelTrained\\'
    TEST_IMG = 'D:\\05_PycharmProjects\\test\\inceptionV3\\data\\test\\'

def _load_graph(graph_path):
    """Parse the serialized GraphDef at ``graph_path`` into the default graph."""
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name prefix, so tensors can be
    # looked up by their original names (e.g. 'output/prob:0').
    tf.import_graph_def(graph_def, name='')


def _print_predictions(sess, softmax_tensor, image_path, labels):
    """Run one image through the graph and print its classes, best first.

    Args:
        sess: Open ``tf.Session`` whose graph contains ``softmax_tensor``.
        softmax_tensor: Output tensor holding one score per class.
        image_path: Path of the image file to classify.
        labels: Class names, indexed by output position.
    """
    # Read the raw bytes and close the handle right away (the original
    # left one file open per image).
    with tf.gfile.FastGFile(image_path, 'rb') as f:
        image_data = f.read()
    # Raw encoded bytes are fed to 'DecodeJpeg/contents:0'; the images are
    # expected to be JPEG files.
    predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
    # Drop singleton dimensions so the scores form a 1-D vector.
    predictions = np.squeeze(predictions)

    print(image_path)
    # Class indices ordered from highest to lowest score.
    top_k = predictions.argsort()[::-1]
    print(top_k)
    for node_id in top_k:
        print('%s (score = %.5f)' % (labels[node_id], predictions[node_id]))
    print()


def predict(_):
    """Classify every file under TEST_IMG with the trained graph.

    Loads ``TRAINED_MODEL_DIR + TRAINED_MODEL_FILE`` into the default graph.
    Then, for each file found by walking TEST_IMG, it prints the file path
    followed by every class label and its score, highest score first. On
    non-Linux hosts each image is also shown in an OpenCV window, and the
    function waits for a key press before moving to the next image.

    Args:
        _: Argument list passed in by ``tf.app.run``; unused.
    """
    labels = ['blackScreen', 'normalScreen', 'other']
    # Display needs cv2, which is imported only on non-Linux hosts. On Linux
    # the display step is skipped, but every image is still classified (the
    # original returned after the first image).
    show_images = platform.system() not in ['Linux']

    _load_graph(TRAINED_MODEL_DIR + TRAINED_MODEL_FILE)

    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('output/prob:0')
        for root, _dirs, files in os.walk(TEST_IMG):
            for file_name in files:
                image_path = os.path.join(root, file_name)
                _print_predictions(sess, softmax_tensor, image_path, labels)
                if show_images:
                    img = cv2.imread(image_path)
                    cv2.imshow('image', img)
                    cv2.waitKey(0)
        if show_images:
            cv2.destroyAllWindows()

if __name__ == '__main__':
    # tf.app.run parses command-line flags, then calls predict(argv).
    tf.app.run(predict)