facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0
30.55k stars 7.49k forks source link

PointRend: No object named 'PointRendMaskHead' found in 'ROI_MASK_HEAD' registry! #2611

Closed cateberry closed 3 years ago

cateberry commented 3 years ago

I am trying to fine-tune a PointRend model on a new dataset, but when I run the training script I get the error: KeyError: "No object named 'PointRendMaskHead' found in 'ROI_MASK_HEAD' registry!"

  1. Training script (mostly taken from train_net.py):
    
    import detectron2.data.transforms as T
    import detectron2.utils.comm as comm
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from detectron2.data import DatasetMapper, MetadataCatalog, build_detection_train_loader
    from detectron2.engine import default_argument_parser, default_setup, launch
    from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    SemSegEvaluator,
    verify_results,
    )
    from detectron2.projects.point_rend import add_pointrend_config

class Trainer(DefaultTrainer):
    """
    Trainer built on "DefaultTrainer", which contains pre-defined default logic
    for the standard training workflow. It may not work for you, especially if
    you are working on a new research project. In that case you can use the
    cleaner "SimpleTrainer", or write your own training loop.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create the evaluator for a given dataset.

        This uses the special metadata "evaluator_type" associated with each
        builtin dataset. For your own dataset, you can simply create an
        evaluator manually in your script and do not have to worry about the
        if-chain here.

        Args:
            cfg: config object; only ``cfg.OUTPUT_DIR`` is read here.
            dataset_name: name used to look up the dataset in ``MetadataCatalog``.
            output_folder: passed as ``output_dir`` to the LVIS/COCO evaluators.
                Defaults to ``<cfg.OUTPUT_DIR>/inference`` when ``None``.

        Returns:
            An ``LVISEvaluator``, ``COCOEvaluator`` or
            ``CityscapesInstanceEvaluator``, depending on the evaluator type.

        Raises:
            NotImplementedError: if the dataset's evaluator type is none of
                "lvis", "coco" or "cityscapes_instance".
            AssertionError: for "cityscapes_instance" when this process's rank
                is not smaller than the number of visible CUDA devices.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if evaluator_type == "coco":
            return COCOEvaluator(dataset_name, output_dir=output_folder)
        if evaluator_type == "cityscapes_instance":
            # NOTE(review): assumes comm.get_rank() is a zero-based global rank,
            # so on a single machine every rank is strictly below the local
            # device count. The previous ">=" also accepted rank == device_count,
            # i.e. the first process of a second machine.
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        # Every supported type returned above. The old code collected evaluators
        # in a list that was never filled, so it always ended up raising here.
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(
                dataset_name, evaluator_type
            )
        )

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build the training data loader.

        A ``DatasetMapper`` with the augmentations from
        ``build_sem_seg_train_aug(cfg)`` is used when
        ``cfg.MODEL.META_ARCHITECTURE`` contains "SemanticSegmentor"; otherwise
        ``mapper=None`` is passed to ``build_detection_train_loader``.
        """
        if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
            mapper = DatasetMapper(cfg, is_train=True, augmentations=build_sem_seg_train_aug(cfg))
        else:
            mapper = None
        return build_detection_train_loader(cfg, mapper=mapper)

def main():
    """
    Configure and run a short PointRend fine-tuning job on "fibre_dataset".

    Starts from the default config, applies ``add_pointrend_config``, merges the
    PointRend R50-FPN 3x COCO YAML, overrides dataset/solver/ROI-head settings,
    then builds a ``Trainer`` and trains from the configured weights.

    Returns:
        Whatever ``trainer.train()`` returns.
    """
    cfg = get_cfg()

    # Presumably registers the PointRend-specific config keys; it is applied
    # before the PointRend YAML is merged on top of the defaults.
    add_pointrend_config(cfg)

    cfg.merge_from_file('/content/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml')
    cfg.DATASETS.TRAIN = ("fibre_dataset",)
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl"
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
    # Only 10 iterations: enough for a quick smoke test. The original note
    # suggested ~300 for a toy dataset and longer for a practical dataset.
    cfg.SOLVER.MAX_ITER = 10
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # faster, and good enough for this toy dataset (default: 512)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.freeze()

    trainer = Trainer(cfg)
    # resume=False: do not resume from a previous run's checkpoint
    trainer.resume_or_load(resume=False)
    return trainer.train()


main()

2. Full error:

KeyError Traceback (most recent call last)

in () ----> 1 main() 13 frames in main() 81 cfg.freeze() 82 ---> 83 trainer = Trainer(cfg) 84 trainer.resume_or_load(resume=False) 85 return trainer.train() /usr/local/lib/python3.6/dist-packages/detectron2/engine/defaults.py in __init__(self, cfg) 280 281 # Assume these objects must be constructed in this order. --> 282 model = self.build_model(cfg) 283 optimizer = self.build_optimizer(cfg, model) 284 data_loader = self.build_train_loader(cfg) /usr/local/lib/python3.6/dist-packages/detectron2/engine/defaults.py in build_model(cls, cfg) 432 Overwrite it if you'd like a different model. 433 """ --> 434 model = build_model(cfg) 435 logger = logging.getLogger(__name__) 436 logger.info("Model:\n{}".format(model)) /usr/local/lib/python3.6/dist-packages/detectron2/modeling/meta_arch/build.py in build_model(cfg) 19 """ 20 meta_arch = cfg.MODEL.META_ARCHITECTURE ---> 21 model = META_ARCH_REGISTRY.get(meta_arch)(cfg) 22 model.to(torch.device(cfg.MODEL.DEVICE)) 23 return model /usr/local/lib/python3.6/dist-packages/detectron2/config/config.py in wrapped(self, *args, **kwargs) 179 180 if _called_with_cfg(*args, **kwargs): --> 181 explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) 182 init_func(self, **explicit_args) 183 else: /usr/local/lib/python3.6/dist-packages/detectron2/config/config.py in _get_args_from_config(from_config_func, *args, **kwargs) 234 if name not in supported_arg_names: 235 extra_kwargs[name] = kwargs.pop(name) --> 236 ret = from_config_func(*args, **kwargs) 237 # forward the other arguments to __init__ 238 ret.update(extra_kwargs) /usr/local/lib/python3.6/dist-packages/detectron2/modeling/meta_arch/rcnn.py in from_config(cls, cfg) 77 "backbone": backbone, 78 "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), ---> 79 "roi_heads": build_roi_heads(cfg, backbone.output_shape()), 80 "input_format": cfg.INPUT.FORMAT, 81 "vis_period": cfg.VIS_PERIOD, 
/usr/local/lib/python3.6/dist-packages/detectron2/modeling/roi_heads/roi_heads.py in build_roi_heads(cfg, input_shape) 41 """ 42 name = cfg.MODEL.ROI_HEADS.NAME ---> 43 return ROI_HEADS_REGISTRY.get(name)(cfg, input_shape) 44 45 /usr/local/lib/python3.6/dist-packages/detectron2/config/config.py in wrapped(self, *args, **kwargs) 179 180 if _called_with_cfg(*args, **kwargs): --> 181 explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) 182 init_func(self, **explicit_args) 183 else: /usr/local/lib/python3.6/dist-packages/detectron2/config/config.py in _get_args_from_config(from_config_func, *args, **kwargs) 234 if name not in supported_arg_names: 235 extra_kwargs[name] = kwargs.pop(name) --> 236 ret = from_config_func(*args, **kwargs) 237 # forward the other arguments to __init__ 238 ret.update(extra_kwargs) /usr/local/lib/python3.6/dist-packages/detectron2/modeling/roi_heads/roi_heads.py in from_config(cls, cfg, input_shape) 557 ret.update(cls._init_box_head(cfg, input_shape)) 558 if inspect.ismethod(cls._init_mask_head): --> 559 ret.update(cls._init_mask_head(cfg, input_shape)) 560 if inspect.ismethod(cls._init_keypoint_head): 561 ret.update(cls._init_keypoint_head(cfg, input_shape)) /usr/local/lib/python3.6/dist-packages/detectron2/modeling/roi_heads/roi_heads.py in _init_mask_head(cls, cfg, input_shape) 630 else: 631 shape = {f: input_shape[f] for f in in_features} --> 632 ret["mask_head"] = build_mask_head(cfg, shape) 633 return ret 634 /usr/local/lib/python3.6/dist-packages/detectron2/modeling/roi_heads/mask_head.py in build_mask_head(cfg, input_shape) 288 """ 289 name = cfg.MODEL.ROI_MASK_HEAD.NAME --> 290 return ROI_MASK_HEAD_REGISTRY.get(name)(cfg, input_shape) /usr/local/lib/python3.6/dist-packages/fvcore/common/registry.py in get(self, name) 70 if ret is None: 71 raise KeyError( ---> 72 "No object named '{}' found in '{}' registry!".format(name, self._name) 73 ) 74 return ret KeyError: "No object named 'PointRendMaskHead' found in 
'ROI_MASK_HEAD' registry!"

I am running this code in Google Colab, mostly using pre-existing code from the official tutorials for importing dependencies, and I haven't edited any of the base files in the point_rend project directory. Am I missing something?
github-actions[bot] commented 3 years ago

You've chosen to report an unexpected problem or bug. Unless you already know the root cause of it, please include details about it by filling the issue template. The following information is missing: "Instructions To Reproduce the Issue and Full Logs";

ppwwyyxx commented 3 years ago

Your PointRend configs seem to come from the latest GitHub master. Please use the PointRend project at v0.3 if you're using detectron2 v0.3.