Sunny599 opened this issue 1 year ago
@Tiiiktak
Sorry for the late reply.
For the COCO14 val set, the test stage took us about 2.5 h.
For the CRF stage, we re-implemented the multiprocessing part, which resulted in crf_coco.py.
With n_jobs=16 it took about 30 h; you can tune the --n_jobs argument for a further speedup.
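For reference, the two stages are launched roughly like this (the config path below is only a placeholder for whichever COCO config file you use):

    python main.py test [OPTIONS]                                  # test stage, ~2.5 h on COCO14 val
    python crf_coco.py path/to/your_coco_config.yaml --n_jobs 16   # CRF stage, ~30 h with 16 workers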
crf_coco.py:
import torch
import numpy as np
import time
import os
import torch.nn.functional as F
import multiprocessing as mp
from multiprocessing import Process
from omegaconf import OmegaConf
import json
import argparse
from tqdm import tqdm
from libs.utils import DenseCRF, PolynomialLR, scores
from main_v2 import get_dataset, makedirs

def process_crf(i, dataset, logit_dir, postprocessor):
    # Load the saved logit for one image, upsample it to the image size,
    # convert it to class probabilities, and refine it with dense CRF.
    image_id, image, gt_label = dataset.__getitem__(i)
    filename = os.path.join(logit_dir, image_id + ".npy")
    logit = np.load(filename)

    _, H, W = image.shape
    logit = torch.FloatTensor(logit)[None, ...]
    logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False)
    prob = F.softmax(logit, dim=1)[0].numpy()

    image = image.astype(np.uint8).transpose(1, 2, 0)
    prob = postprocessor(image, prob)
    label = np.argmax(prob, axis=0)

    return label, gt_label

def crf(dataset, logit_dir, postprocessor, num_workers=4):
    print("CRF post-processing")
    pbar = tqdm(total=len(dataset), desc="CRF post-processing", ascii=True)

    # Advance the progress bar each time a worker finishes one image.
    def update(*a):
        pbar.update()

    # Dispatch one CRF job per image to the worker pool.
    pool = mp.Pool(num_workers)
    results = []
    for i in range(len(dataset)):
        results.append(pool.apply_async(process_crf,
                                        args=(i, dataset, logit_dir, postprocessor),
                                        callback=update))
    pool.close()
    pool.join()
    results = [r.get() for r in results]
    print("CRF post-processing finished")
    # print("Results:", results)
    return results

def main(config_path, n_jobs):
    # Configuration
    CONFIG = OmegaConf.load(config_path)
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        quit()

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf_coco.json")
    print("Score dst:", save_path)

    # CRF
    results = crf(dataset, logit_dir, postprocessor, num_workers=n_jobs)

    # Evaluation
    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print(f'mIoU: {score["Mean IoU"]}')

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", type=str)
    parser.add_argument("--n_jobs", type=int, default=4)
    args = parser.parse_args()
    main(args.config_path, args.n_jobs)
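As a side note, if multiprocessing.Pool is inconvenient in your environment, the same per-image function can be parallelized with joblib instead. This is only a sketch, not part of the repo; it assumes process_crf, dataset, logit_dir, and postprocessor are already set up as in crf_coco.py above, and joblib is an extra dependency:

import joblib

# Refine every image in parallel; results is a list of (predicted_label, gt_label)
# pairs, exactly as returned by the mp.Pool-based crf() above.
results = joblib.Parallel(n_jobs=16, verbose=10)(
    joblib.delayed(process_crf)(i, dataset, logit_dir, postprocessor)
    for i in range(len(dataset))
)
preds, gts = zip(*results)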
Hello, thank you very much for your code; it has been a great help to my work. Roughly how long does testing on the COCO 2014 val set take when using the DeepLab v2 and post-processing code you provide?