ZJULearning / resa

Implementation of our paper 'RESA: Recurrent Feature-Shift Aggregator for Lane Detection' in AAAI2021.
Apache License 2.0
175 stars 36 forks source link

Problem of Refering inference (detect.py) #19

Closed KnightHarute closed 3 years ago

KnightHarute commented 3 years ago

Thanks for the latest reply. I just checked the detect.py underTuroad/lanedet project. The detect.py works perfectly in that project. However, when I simply copy-pasted it into this project, and use culane.py as a config file. Below is the detect.py code. Basically, I just deleted self.processes = Process(cfg.val_process, cfg) and data = self.processes(data) these two lines. Because I didn't change the config file culane.py, and there is no val_process in the config file.

`import numpy as np import torch import cv2 import os import os.path as osp import glob import argparse from lanedet.datasets.process import Process from lanedet.models.registry import build_net from lanedet.utils.config import Config from lanedet.utils.visualization import imshow_lanes from lanedet.utils.net_utils import load_network from pathlib import Path

class Detect(object): def init(self, cfg): self.cfg = cfg

self.processes = Process(cfg.val_process, cfg)

    self.net = build_net(self.cfg)
    self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids = range(1)).cuda()
    self.net.eval()
    load_network(self.net, self.cfg.load_from)

def preprocess(self, img_path):
    ori_img = cv2.imread(img_path)
    img = ori_img[self.cfg.cut_height:, :, :].astype(np.float32)
    data = {'img': img}
    # data = self.processes(data)
    data['img'] = data['img'].unsqueeze(0)
    data.update({'img_path':img_path, 'ori_img':ori_img})
    return data

def inference(self, data):
    with torch.no_grad():
        data = self.net(data)
    return data

def show(self, data):
    out_file = self.cfg.savedir 
    if out_file:
        out_file = osp.join(out_file, osp.basename(data['img_path']))
    lanes = [lane.to_array(self.cfg) for lane in data['lanes']]
    imshow_lanes(data['ori_img'], lanes, show=self.cfg.show, out_file=out_file)

def run(self, data):
    data = self.preprocess(data)
    data['lanes'] = self.inference(data)[0]
    if self.cfg.show or self.cfg.savedir:
        self.show(data)
    return data

def get_img_paths(path): p = str(Path(path).absolute()) # os-agnostic absolute path if '' in p: paths = sorted(glob.glob(p, recursive=True)) # glob elif os.path.isdir(p): paths = sorted(glob.glob(os.path.join(p, '.*'))) # dir elif os.path.isfile(p): paths = [p] # files else: raise Exception(f'ERROR: {p} does not exist') return paths

def process(args): cfg = Config.fromfile(args.config) cfg.show = args.show cfg.savedir = args.savedir cfg.load_from = args.load_from detect = Detect(cfg) paths = get_img_paths(args.img) for p in paths: detect.run(p)

if name == 'main': parser = argparse.ArgumentParser() parser.add_argument('config', help='The path of config file') parser.add_argument('--img', help='The path of the img (img file or img_folder), for example: data/*.png') parser.add_argument('--show', action='store_true', help='Whether to show the image') parser.add_argument('--savedir', type=str, default=None, help='The root of save directory') parser.add_argument('--load_from', type=str, default='best.pth', help='The path of model') args = parser.parse_args() process(args) `

It shows an error (venv) F:\code\ZJULearning-resa\resa-main>python tools\detect.py configs\culane.py --img save_3\ --load_from culane_resnet50.pth --savedir out_3_test\ Traceback (most recent call last): File "tools\detect.py", line 88, in <module> process(args) File "tools\detect.py", line 74, in process detect = Detect(cfg) File "tools\detect.py", line 19, in __init__ self.net = build_net(self.cfg) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 32, in build_net return build(cfg.net, NET, default_args=dict(cfg=cfg)) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 16, in build return build_from_cfg(cfg, registry, default_args) File "f:\code\lanedet\lanedet-main\lanedet\utils\registry.py", line 71, in build_from_cfg raise KeyError('{} is not in the {} registry'.format( KeyError: 'RESANet is not in the net registry' Then I tried to change the type='RESANet' into type=Detector, just like the config file from Turoad/lanedet project.

And it shows another error: (venv) F:\code\ZJULearning-resa\resa-main>python tools\detect.py configs\culane.py --img save_3\ --load_from culane_resnet50.pth --savedir out_3_test\ pretrained model: https://download.pytorch.org/models/resnet50-19c8e357.pth Traceback (most recent call last): File "tools\detect.py", line 88, in <module> process(args) File "tools\detect.py", line 74, in process detect = Detect(cfg) File "tools\detect.py", line 19, in __init__ self.net = build_net(self.cfg) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 32, in build_net return build(cfg.net, NET, default_args=dict(cfg=cfg)) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 16, in build return build_from_cfg(cfg, registry, default_args) File "f:\code\lanedet\lanedet-main\lanedet\utils\registry.py", line 81, in build_from_cfg return obj_cls(**args) File "f:\code\lanedet\lanedet-main\lanedet\models\net\detector.py", line 13, in __init__ self.backbone = build_backbone(cfg) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 20, in build_backbone return build(cfg.backbone, BACKBONE, default_args=dict(cfg=cfg)) File "f:\code\lanedet\lanedet-main\lanedet\models\registry.py", line 16, in build return build_from_cfg(cfg, registry, default_args) File "f:\code\lanedet\lanedet-main\lanedet\utils\registry.py", line 81, in build_from_cfg return obj_cls(**args) File "f:\code\lanedet\lanedet-main\lanedet\models\backbone\resnet.py", line 144, in __init__ out_channel * self.model.expansion, cfg.featuremap_out_channel) File "f:\code\lanedet\lanedet-main\lanedet\utils\config.py", line 327, in __getattr__ return getattr(self._cfg_dict, name) File "f:\code\lanedet\lanedet-main\lanedet\utils\config.py", line 40, in __getattr__ raise ex AttributeError: 'ConfigDict' object has no attribute 'featuremap_out_channel'

Could you help me out, please?

Turoad commented 3 years ago

This refering is in project: https://github.com/Turoad/lanedet. Please try the detect.py in that project.