wuzhe71 / CPD

Code of Cascaded Partial Decoder for Fast and Accurate Salient Object Detection (CVPR2019)

Question about how the FPS of your code was computed #7

Closed: lartpang closed this issue 5 years ago

lartpang commented 5 years ago

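I recently needed to take inference speed into account, so I benchmarked your code myself using the approach below. However, my result is far from the number reported in your paper, so I'd like to ask you about it.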

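Saving the output images "myway" is somewhat faster than the approach in your test_CPD.py, ~~but it still only reaches 29.6 FPS~~. Note that I ran it only once, directly on DUTS-TE, processing one image at a time. Could that differ from how you measured it?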

NOTE

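Edited 2019-07-21 21:08:56: When I tested again, it ran noticeably faster. The earlier slowdown was probably because my computer was also running other programs that used a lot of disk I/O at the time. This run measured FPS: 49.78818400079663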
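Because the script below times image loading, preprocessing, the forward pass and PNG saving together, disk I/O from other programs can shift the FPS figure considerably. A minimal sketch for attributing time to each stage separately, assuming plain wall-clock timing with `time.perf_counter`; the class `StageTimer` and the stage names are hypothetical and not part of the CPD repository:

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulates wall-clock time per named stage (hypothetical helper)."""

    def __init__(self):
        self.totals = defaultdict(float)  # stage name -> accumulated seconds

    @contextmanager
    def stage(self, name):
        # Time the body of the `with` block and add it to this stage's total.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start

    def fps(self, n_images, stages=None):
        # Images per second over the selected stages (all recorded stages by default).
        names = list(self.totals) if stages is None else stages
        total = sum(self.totals[n] for n in names)
        return n_images / total if total > 0 else float("inf")


if __name__ == "__main__":
    timer = StageTimer()
    for _ in range(3):
        with timer.stage("load"):
            time.sleep(0.002)  # stand-in for Image.open(...) + transforms
        with timer.stage("forward"):
            time.sleep(0.001)  # stand-in for self.net(inputs)
    print(dict(timer.totals))
    print("forward-only FPS:", timer.fps(3, ["forward"]))
```

In the loop below, `Image.open` and `self.test_img_transform` would go under a `"load"` stage, `self.net(inputs)` under `"forward"`, and the `.cpu()`/`save` lines under `"save"`. On a GPU, CUDA kernels run asynchronously, so `torch.cuda.synchronize()` would need to be called before leaving the `"forward"` block; otherwise the kernel time is absorbed by the blocking `.cpu()` copy in the next stage.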


# -*- coding: utf-8 -*-
# @Time    : 2019/7/21 9:10 PM
# @Author  : Lart Pang
# @FileName: new_fps.py
# @Project : CPD
# @GitHub  : https://github.com/lartpang

import os
import time

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from model.CPD_ResNet_models import CPD_ResNet

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.cuda.empty_cache()
torch.multiprocessing.set_sharing_strategy('file_system')
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False

def check_mkdir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

class FPS():
    def __init__(self, proj_name, args):
        super(FPS, self).__init__()
        self.args = args
        self.to_pil = transforms.ToPILImage()
        self.proj_name = proj_name
        self.dev = torch.device("cuda:0")
        self.net = self.args[proj_name]['net']().to(self.dev)
        self.net.eval()

        self.test_img_transform = transforms.Compose([
            transforms.Resize((self.args['crop_size'], self.args['crop_size'])),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def get_fps(self, data_path, save_path):
        print(f'Save path: {save_path}')
        check_mkdir(save_path)

        print(f'Start testing... {data_path}')
        img_path = os.path.join(data_path, 'Image')
        img_list = os.listdir(img_path)

        start_time = time.time()
        tqdm_iter = tqdm(enumerate(img_list), total=len(img_list), leave=False)
        for idx, img_name in tqdm_iter:
            tqdm_iter.set_description(f"{self.proj_name}:te=>{idx + 1}")
            img_fullpath = os.path.join(img_path, img_name)
            test_data = Image.open(img_fullpath).convert('RGB')
            img_size = test_data.size
            test_data = self.test_img_transform(test_data)
            test_data = test_data.unsqueeze(0)

            inputs = test_data.to(self.dev)
            with torch.no_grad():
                _, res = self.net(inputs)
            # res = F.interpolate(res, size=img_size, mode='bilinear', align_corners=False)
            # res = res.sigmoid().data.cpu().numpy().squeeze()
            # res = (res - res.min()) / (res.max() - res.min() + 1e-8) * 255
            # oimg_path = os.path.join(save_path, img_name[:-4] + '.png')
            # imwrite(oimg_path, res.astype(numpy.uint8))

            # "myway": convert with ToPILImage, resize with PIL, save as PNG (29.6 FPS in the first run)
            res = res.sigmoid().cpu().detach().squeeze(0)
            res = self.to_pil(res).resize(img_size)
            oimg_path = os.path.join(save_path, img_name[:-4] + '.png')
            res.save(oimg_path)

        total_time = time.time() - start_time
        fps = len(img_list) / total_time
        return fps

if __name__ == '__main__':
    proj_list = ['CPD']

    data_dicts = {
        'duts': '/home/lart/Datasets/RGBSaliency/DUTS/Test',
    }

    arg_dicts = {
        'CPD'      : {
            'net'     : CPD_ResNet,
            'exp_name': 'CPD_ResNet'
        },
        'crop_size': 352,
    }

    result = {}
    for proj_name in proj_list:
        result[arg_dicts[proj_name]['exp_name']] = {}

        for data_name, data_path in data_dicts.items():
            save_path = (f"/home/lart/Coding/CPD/output/"
                         f"{arg_dicts[proj_name]['exp_name']}/pre/{data_name}")
            fpser = FPS(proj_name, arg_dicts)
            fps = fpser.get_fps(data_path, save_path)
            print(f"FPS:{fps}")
            del fpser
    print('Testing finished')
wuzhe71 commented 5 years ago

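@lartpang The times listed in the paper are all forward-pass times only. When you compare methods, just use the same testing procedure for all of them.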
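A forward-only number, as described in the reply, excludes file reading, preprocessing and saving. A minimal sketch of such a measurement, assuming the caller passes a zero-argument `forward` callable and, for CUDA, `torch.cuda.synchronize` as `sync`; the function name `forward_fps` and its default iteration counts are hypothetical, not taken from the CPD code:

```python
import time
from typing import Callable, Optional


def forward_fps(
    forward: Callable[[], object],
    n_iters: int = 100,
    n_warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return forward passes per second, excluding data loading and saving."""
    if n_iters <= 0:
        raise ValueError("n_iters must be positive")
    # Warm-up: early calls pay one-off costs (CUDA context, cuDNN setup, caches).
    for _ in range(n_warmup):
        forward()
    if sync is not None:
        sync()  # ensure warm-up work has finished before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        forward()
    if sync is not None:
        sync()  # wait for queued asynchronous (e.g. GPU) work to complete
    elapsed = time.perf_counter() - start
    return n_iters / elapsed if elapsed > 0 else float("inf")
```

With the script above, this could be called as `forward_fps(lambda: self.net(inputs), sync=torch.cuda.synchronize)` inside `torch.no_grad()`, after moving a single preprocessed 352x352 `inputs` tensor to `cuda:0`. The resulting figure is only comparable with other models' figures when they are timed the same way.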