google / automl

Google Brain AutoML
Apache License 2.0

Speed test code available. You can modify it to predict on your own images. #42

Closed YanqingWu closed 4 years ago

YanqingWu commented 4 years ago

I solved this myself. In my tests, d0 is faster than YOLOv3 (same image size, PyTorch version) on a 2080 Ti. Below is the source code, so you can test it on your own device. All you need to do is download the checkpoint (required) and put it under the efficientdet directory:

1. First step: download the checkpoint.

    cd efficientdet
    ipython
    MODEL = 'efficientdet-d0'  #@param
    !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/{MODEL}.tar.gz
    !tar zxf {MODEL}.tar.gz

Change MODEL to d1, d2, ... and repeat for the other model sizes.
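If you are not using IPython, the same download step could be done from plain Python. The following is a minimal sketch that uses only the standard library. The helper names `checkpoint_url` and `download_checkpoint` are made up for illustration, the URL pattern is copied from the wget command above, and the sketch assumes the archive unpacks into a directory named after the model.

```python
import tarfile
import urllib.request
from pathlib import Path

# URL prefix taken from the wget command above.
BASE_URL = "https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco"


def checkpoint_url(model: str) -> str:
    """Return the download URL for a model name such as 'efficientdet-d0'."""
    return f"{BASE_URL}/{model}.tar.gz"


def download_checkpoint(model: str, dest_dir: str = ".") -> Path:
    """Download and unpack the checkpoint archive for `model` into `dest_dir`.

    Hypothetical helper: assumes the archive unpacks into a directory named
    after the model, which is what test.py expects to find.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    archive = dest / f"{model}.tar.gz"
    # Fetch the archive, then extract it next to where it was saved.
    urllib.request.urlretrieve(checkpoint_url(model), str(archive))
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)
    return dest / model
```

Calling `download_checkpoint('efficientdet-d0')` from inside the efficientdet directory should leave the checkpoint where the test script looks for it.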

2. Second step: create a new test.py file and paste in the code below.

    cd efficientdet
    touch test.py

Then run it:

    python test.py --model 0

3. You can also adapt the code into your own predict.py to run on real images instead of the fake (random) input.

---------------------------------- source code ---------------------------

import time
import argparse

import numpy as np
import tensorflow.compat.v1 as tf

import utils
import hparams_config
import efficientdet_arch


def test(imgs, model_name, img_size=512):
    """Time the forward pass of `model_name` on each image in `imgs`."""
    with tf.Session() as sess:
        # Build the graph for a batch of one image.
        X = tf.placeholder(tf.float32, shape=(1, img_size, img_size, 3))
        class_outputs, box_outputs = efficientdet_arch.efficientdet(X, model_name=model_name)

        # If model_name is a checkpoint directory, resolve it to the latest checkpoint.
        ckpt_path = model_name
        if tf.io.gfile.isdir(ckpt_path):
            ckpt_path = tf.train.latest_checkpoint(ckpt_path)

        # Initialize once, then restore the EMA (moving average) weights.
        var_dict = utils.get_ema_vars()
        tf.train.get_or_create_global_step()
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(var_dict, max_to_keep=1)
        saver.restore(sess, ckpt_path)

        times = []
        for img in imgs:
            img = img[np.newaxis, ...]
            start = time.time()
            # Note: only class_outputs is fetched, so ops needed solely for
            # box_outputs are not executed in this run.
            sess.run(class_outputs, feed_dict={X: img})
            spent = time.time() - start
            print(spent)
            times.append(spent)
        # Skip the first run, which includes one-time warm-up overhead.
        print('mean time of %d runs: %.4f' % (len(times) - 1, np.mean(times[1:])))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('test')
    parser.add_argument('--model', type=str, default='0')
    args = parser.parse_args()
    model = 'efficientdet-d%s' % args.model
    img_size = hparams_config.efficientdet_model_param_dict[model]['image_size']
    # 100 random "fake" images at the model's native input size.
    test(np.random.randn(100, img_size, img_size, 3), model, img_size=img_size)

YanqingWu commented 4 years ago

I am not very familiar with TensorFlow, so there may be some redundant code. If so, please @ me. Thanks.
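The measurement pattern used in test.py above (time each run, then average while ignoring the first, slower run) can be sketched independently of TensorFlow. This is only an illustration: `benchmark` and its `warmup` parameter are hypothetical names, not part of the repository, and the workload here is a stand-in for the real `sess.run` call.

```python
import time
from statistics import mean


def benchmark(fn, runs=100, warmup=1):
    """Call `fn` `warmup + runs` times and return the mean time of the last `runs` calls.

    The warm-up calls are executed but excluded from the average, since the
    first calls often include one-time setup cost.
    """
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return mean(times)


# Stand-in workload; in test.py this would be something like
# lambda: sess.run(class_outputs, feed_dict={X: img}).
mean_seconds = benchmark(lambda: sum(range(10_000)), runs=20, warmup=2)
print('mean time of 20 runs: %.6f s' % mean_seconds)
```

Keeping the warm-up count as an explicit parameter makes it harder for the reported number of runs to drift out of sync with the number of runs actually averaged.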

WonTaeYeon commented 4 years ago

Thanks, your code is very useful.

mingxingtan commented 4 years ago

Thanks @YanqingWu

If you want to benchmark the model latency, you can try this command:

python model_inspect.py --model_name=efficientdet-d0 --runmode=bm --bm_runs=100

If you want to predict a single image, here is a simple utility:

https://github.com/google/automl/blob/7768c499ede4a732fde022a24d8525b3396b50a4/efficientdet/inference.py#L240

mingxingtan commented 4 years ago

Example code is like this:

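    import inference  # inference.py from the efficientdet directory
    from PIL import Image
    import tensorflow.compat.v1 as tf

    driver = inference.ServingDriver('efficientdet-d0', '/tmp/efficientdet-d0')
    driver.build()
    for f in tf.io.gfile.glob('/tmp/*.jpg'):
      image = Image.open(f)
      predictions = driver.serve(image)
      out_image = driver.visualize(image, predictions[0])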

I am going to close this issue for now, but if you have further questions, please feel free to reopen it. Thanks!

cyx6666 commented 4 years ago

I got this error when running the example code:

RuntimeError                              Traceback (most recent call last)
<ipython-input-48-4c96731ccaf1> in <module>()
      3 import tensorflow.compat.v1 as tf
      4 driver = inference.ServingDriver('efficientdet-d4', 'trained_model')
----> 5 driver.build()
      6 for f in tf.io.gfile.glob('test/*.jpg'):
      7   image = Image.open(f)

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in placeholder(dtype, shape, name)
   3021   """
   3022   if context.executing_eagerly():
-> 3023     raise RuntimeError("tf.placeholder() is not compatible with "
   3024                        "eager execution.")
   3025 

RuntimeError: tf.placeholder() is not compatible with eager execution.

wangtianlong1994 commented 4 years ago

Hi, I ran into the same error too. Did you manage to solve it?

mingxingtan commented 4 years ago

Sorry for the issue. How about adding tf.compat.v1.disable_v2_behavior() at the beginning of your code?

wangtianlong1994 commented 4 years ago

Very effective

varghesealex90 commented 4 years ago

One question: this code only measures the forward pass, right? Or does it also include post-processing?