CVI-SZU / CLIMS

[CVPR 2022] CLIMS: Cross Language Image Matching for Weakly Supervised Semantic Segmentation
MIT License
124 stars 12 forks source link

test time #26

Open Sunny599 opened 1 year ago

Sunny599 commented 1 year ago

您好，非常感谢您的代码，为我的工作提供了很多帮助。请问使用您提供的 deeplabv2 以及后处理代码，在 coco2014 val 数据集上测试大概需要多长时间？

Sierkinhane commented 1 year ago

@Tiiiktak

Tiiiktak commented 1 year ago

Sorry for the late reply.

对于coco14 val 集,在 test 阶段我们花费了约2.5h

在 crf 阶段，我们重新实现了多进程部分，得到了 crf_coco.py。当 n_jobs=16 时，用时约 30h。

你可以调整参数 --n_jobs 以寻求进一步提速。

crf_coco.py:

import torch 
import numpy as np 
import time 
import os 
import torch.nn.functional as F
import multiprocessing as mp
from multiprocessing import Process
from omegaconf import OmegaConf
import json 
import argparse
from tqdm import tqdm

from libs.utils import DenseCRF, PolynomialLR, scores

from main_v2 import get_dataset, makedirs

def process_crf(i, dataset, logit_dir, postprocessor):
    """Refine one sample's saved logits with a CRF and return its label map.

    Args:
        i: Index of the sample in ``dataset``.
        dataset: Indexable dataset whose items unpack as
            ``(image_id, image, gt_label)``; ``image`` is channel-first
            (``C, H, W``), as the shape unpacking below requires.
        logit_dir: Directory containing one ``<image_id>.npy`` logit file
            per image.
        postprocessor: Callable ``(image_hwc_uint8, prob) -> prob`` that
            refines class probabilities, e.g. a ``DenseCRF`` instance.

    Returns:
        ``(label, gt_label)``: ``label`` is the per-pixel argmax over the
        refined probabilities at the image's ``H x W`` resolution;
        ``gt_label`` is passed through unchanged from the dataset.
    """
    image_id, image, gt_label = dataset[i]
    logit = np.load(os.path.join(logit_dir, image_id + ".npy"))

    # The stored logits may be smaller than the image (presumably the
    # network's output resolution -- TODO confirm); resize them to the image
    # size before the softmax so the CRF gets per-pixel probabilities.
    _, H, W = image.shape
    logit = torch.FloatTensor(logit)[None, ...]  # add a batch dim for interpolate
    logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False)
    prob = F.softmax(logit, dim=1)[0].numpy()

    # CHW -> HWC uint8, the layout handed to the CRF postprocessor.
    # NOTE(review): the dataset is built with ``mean_bgr``; if images arrive
    # mean-subtracted, negative values wrap around in this cast -- confirm.
    image = image.astype(np.uint8).transpose(1, 2, 0)
    prob = postprocessor(image, prob)
    label = np.argmax(prob, axis=0)
    return label, gt_label

# Per-worker-process state, filled once by ``_init_crf_worker``. The dataset and
# CRF post-processor are then shipped to each worker a single time instead of
# being pickled again for every task.
_CRF_WORKER_STATE = {}


def _init_crf_worker(dataset, logit_dir, postprocessor):
    """Pool initializer: store the shared CRF inputs in this worker process."""
    _CRF_WORKER_STATE["dataset"] = dataset
    _CRF_WORKER_STATE["logit_dir"] = logit_dir
    _CRF_WORKER_STATE["postprocessor"] = postprocessor


def _crf_worker_task(i):
    """Run ``process_crf`` on sample ``i`` using the worker's stored inputs."""
    return process_crf(
        i,
        _CRF_WORKER_STATE["dataset"],
        _CRF_WORKER_STATE["logit_dir"],
        _CRF_WORKER_STATE["postprocessor"],
    )


def crf(dataset, logit_dir, postprocessor, num_workers=4):
    """CRF-refine every sample of ``dataset`` in a process pool.

    Args:
        dataset: Indexable, sized dataset (see ``process_crf``).
        logit_dir: Directory of per-image ``<image_id>.npy`` logit files.
        postprocessor: CRF callable applied to each sample.
        num_workers: Number of worker processes.

    Returns:
        List of ``(label, gt_label)`` tuples in dataset index order.

    NOTE(review): every full-resolution label map is kept in memory until
    scoring; for large splits this can be substantial.
    """
    print("CRF post-processing")
    results = []
    with tqdm(total=len(dataset), desc="CRF post-processing", ascii=True) as pbar:
        # The context manager terminates the pool if anything raises, so
        # worker processes are never leaked.
        with mp.Pool(
            num_workers,
            initializer=_init_crf_worker,
            initargs=(dataset, logit_dir, postprocessor),
        ) as pool:
            # imap yields results in submission order; a worker exception is
            # re-raised here.
            for result in pool.imap(_crf_worker_task, range(len(dataset))):
                results.append(result)
                pbar.update()
            pool.close()
            pool.join()

    print("CRF post-processing finished")
    return results

def main(config_path, n_jobs):
    """Run CRF post-processing on saved logits and write the resulting scores.

    Loads the OmegaConf config at ``config_path`` and builds the configured
    dataset's validation split. It then CRF-refines each saved logit map
    using ``n_jobs`` worker processes and dumps the metrics returned by
    ``scores`` to ``scores_crf_coco.json`` under
    ``<EXP.OUTPUT_DIR>/scores/<EXP.ID>/<model name>/<val split>/``.

    Exits the process with status 1 if the logit directory does not exist.
    """
    # Configuration
    CONFIG = OmegaConf.load(config_path)
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files produced by the test stage
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        # ``quit()`` is an interactive-shell helper from ``site`` (missing
        # under ``python -S``) and exits with status 0, hiding the failure
        # from calling scripts. Exit with a failure status instead.
        raise SystemExit(1)

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf_coco.json")
    print("Score dst:", save_path)

    # CRF
    results = crf(dataset, logit_dir, postprocessor, num_workers=n_jobs)
    # Evaluation: split the (pred, gt) pairs into parallel sequences
    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print(f'mIoU: {score["Mean IoU"]}')
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(score, f, indent=4, sort_keys=True)

if __name__ == "__main__":
    # CLI: a positional config path plus an optional worker count for the
    # CRF process pool.
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", type=str)
    cli.add_argument("--n_jobs", type=int, default=4)
    cli_args = cli.parse_args()

    main(config_path=cli_args.config_path, n_jobs=cli_args.n_jobs)