sacmehta / ESPNet

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation
https://sacmehta.github.io/ESPNet/
MIT License
542 stars 112 forks

Inference speed measurement #57

Closed: Jason93K closed this issue 5 years ago

Jason93K commented 5 years ago

Hello @sacmehta, thank you for your impressive work. However, I don't know how to measure the inference speed of the network properly. I created an input of size [1, 3, 512, 1024] and measured only the model execution time, excluding the first iteration. When I used torch.cuda.synchronize() or cudnn.benchmark=True, as in ERFNet's eval_forwardTime.py, the speed was not even close to 112 FPS. Could you share the code that you used to measure the inference speed?

P.S.: I use Python 3.7, PyTorch 0.4.1, CUDA 9.0, cuDNN 7.1, and a Titan X GPU.

yun-liu commented 5 years ago

I ran into the same problem, and other small networks seem to show it too: https://github.com/wutianyiRosun/CGNet/issues/2

torch.cuda.synchronize() is needed to measure speed correctly in PyTorch, because CUDA kernels are launched asynchronously.

My code is

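import time
import torch
from torch.autograd import Variable

model.eval()  # `model` is the already-constructed network, on GPU 0
for idx in range(300):
    # Random 1 x 3 x 512 x 1024 input, created outside the timed region
    input = Variable(torch.rand([3, 512, 1024]).unsqueeze(0), requires_grad=False).cuda(0)
    start_time = time.time()
    out = model(input)
    torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    time_taken = time.time() - start_time
    print("Run-Time: %.4f s" % time_taken)

With this code, the maximum speed I get for ESPNet is 61.0 FPS (for 512 x 1024 RGB images) on a TITAN Xp GPU.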

sacmehta commented 5 years ago

Try this code. Please feel free to make changes.

import time

import numpy as np
import torch

def computeTime(model, device='cuda'):
    # Fixed random input: batch of 1, 3 channels, 512 x 1024 pixels
    inputs = torch.randn(1, 3, 512, 1024)
    if device == 'cuda':
        model = model.cuda()
        inputs = inputs.cuda()

    model.eval()

    i = 0
    time_spent = []
    while i < 100:
        start_time = time.time()
        with torch.no_grad():
            _ = model(inputs)

        if device == 'cuda':
            torch.cuda.synchronize()  # wait for cuda to finish (cuda is asynchronous!)
        if i != 0:  # skip the first iteration, which includes one-time setup cost
            time_spent.append(time.time() - start_time)  # elapsed seconds
        i += 1
    print('Avg execution time (ms): {:.3f}'.format(np.mean(time_spent)))

model = ESPNet()
computeTime(model)

yun-liu commented 5 years ago

@sacmehta Thank you very much for your reply! With your code, I find that the runtime for a 512 x 1024 input is between 0.010 and 0.014 seconds on a TITAN Xp GPU, which is close to your results. Maybe the runtime fluctuation comes from the PyTorch framework. Thanks again!
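One way to look at the run-to-run fluctuation mentioned above is to keep every per-iteration time and report the minimum, median, and maximum rather than a single average. The sketch below is only an illustration: `benchmark` and `summarize` are hypothetical helper names, not part of ESPNet or PyTorch, and the workload is a CPU stand-in. For a GPU model one would presumably pass the forward pass as `fn` and `torch.cuda.synchronize` as `sync`.

```python
import statistics
import time


def benchmark(fn, n_iters=100, n_warmup=10, sync=None):
    """Call fn() n_warmup times untimed, then return n_iters timings in seconds."""
    for _ in range(n_warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # assumption: blocks until asynchronous work is done
        times.append(time.perf_counter() - start)
    return times


def summarize(times):
    """Summarize a list of per-iteration times (seconds)."""
    return {
        "min": min(times),
        "median": statistics.median(times),
        "mean": statistics.mean(times),
        "max": max(times),
    }


# Stand-in CPU workload; replace with the model's forward pass.
times = benchmark(lambda: sum(range(10000)), n_iters=20, n_warmup=2)
stats = summarize(times)
print("min {min:.6f} s, median {median:.6f} s, max {max:.6f} s".format(**stats))
```

The median is less sensitive than the mean to an occasional slow iteration, which makes it a reasonable number to compare across machines.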

Jason93K commented 5 years ago

@yun-liu Thank you for pointing out the same issue in the other network. @sacmehta Thank you for replying and sharing your code. I found that ESPNet took about 0.00824 seconds on a Titan Xp GPU, but about 0.01406 seconds on a Titan X GPU...
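To compare these per-image runtimes with a frames-per-second figure, take the reciprocal of the runtime in seconds. A minimal sketch, where `seconds_to_fps` is a hypothetical helper name and the inputs are the two runtimes reported above:

```python
def seconds_to_fps(seconds_per_image):
    """Convert a per-image runtime in seconds to frames per second."""
    if seconds_per_image <= 0:
        raise ValueError("runtime must be positive")
    return 1.0 / seconds_per_image


# Per-image runtimes reported in this thread (seconds).
reported = {"Titan Xp": 0.00824, "Titan X": 0.01406}
for gpu, seconds in reported.items():
    print("{}: {:.1f} FPS".format(gpu, seconds_to_fps(seconds)))
# Titan Xp: 121.4 FPS
# Titan X: 71.1 FPS
```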

sacmehta commented 5 years ago

For benchmarking, please make sure that no other tasks are running on your machine (on either the CPU or the GPU) while measuring inference time.

pyradd commented 4 years ago

@sacmehta Why does print('Avg execution time (ms): {:.3f}'.format(np.mean(time_spent))) have "ms" in it? Isn't that execution time in seconds?
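The values stored in time_spent are differences of time.time() calls, which are in seconds, so the number printed under the "ms" label is a value in seconds. A minimal sketch of one way to make the label and the value agree is to multiply the mean by 1000 before formatting. `format_avg_time` is a hypothetical helper name, and it uses the standard-library `statistics.mean` here in place of `np.mean` only to stay dependency-free.

```python
import statistics


def format_avg_time(time_spent):
    """Format the mean of per-iteration times (given in seconds) as milliseconds."""
    avg_seconds = statistics.mean(time_spent)
    return 'Avg execution time (ms): {:.3f}'.format(avg_seconds * 1000.0)


# Two example timings in seconds, matching the range reported in this thread.
print(format_avg_time([0.010, 0.014]))
# Avg execution time (ms): 12.000
```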