facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0
30.02k stars 7.42k forks source link

PointRend training error with no object image #4383

Open jasonchenPJ opened 2 years ago

jasonchenPJ commented 2 years ago

I added images with no objects to the training set for a PointRend model, but I sometimes get a RuntimeError even with FILTER_EMPTY_ANNOTATIONS=False. I suspect the error occurs when every image sampled into a batch contains no objects. To verify this, I built a dataset consisting entirely of images with no objects and got the following results.

The following link is my dataset: https://drive.google.com/file/d/1sJlKhd6InYEozaN4F2mF7LxeRRPaSV8R/view?usp=sharing

Is there a solution to the problem?

Instructions To Reproduce the Issue:

This is a PointRend variant of Detectron2, using a custom dataset and adding image augmentation.

  1. Full runnable code or full changes you made:
    
    import os
    import cv2
    import detectron2.data.transforms as T
    import detectron2.utils.comm as comm
    import pycocotools
    import torch

from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data import DatasetMapper, build_detection_train_loader from detectron2.data.datasets.coco import load_coco_json from detectron2.engine import default_argument_parser, default_setup, launch from detectron2.engine import DefaultTrainer, BestCheckpointer from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, verify_results from detectron2.projects.point_rend import ColorAugSSDTransform, add_pointrend_config

# Segmentation model: PointRend config file and the pretrained weights to start from.
config_file = './configs/pointrend_rcnn_R_50_FPN_3x_coco.yaml'
weight_file = './model_final_rcnn_R_50_FPN_3x.pkl'

# Dataset: class names, image folders and COCO-format annotation files.
cls_names = ['background', 'HPDF01', 'HAPP01']
train_path = './data/page_train'
valid_path = './data/page_valid'
train_json = os.path.join(train_path, 'annotations.json')
valid_json = os.path.join(valid_path, 'annotations.json')
# Number of training images; used to derive the iterations per epoch.
num_train_data = 36

# Result folder
output_dir = './ckpt'

def setup(args):
    """Create, customise and freeze the detectron2 config used for training.

    Starts from the default config, extends it with the PointRend options,
    merges the PointRend YAML and pretrained weights, then overrides the
    dataset, input, solver and evaluation settings for the custom dataset.

    Args:
        args: parsed command-line arguments, forwarded to ``default_setup``.

    Returns:
        The frozen config.
    """
    # load base config
    cfg = get_cfg()
    # extend config parameters for PointRend
    add_pointrend_config(cfg)
    # apply the PointRend YAML, then point MODEL.WEIGHTS at the pretrained file
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(['MODEL.WEIGHTS', weight_file])

    # network setting: both heads must agree on the number of classes
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = 3

    # data loader setting
    cfg.DATASETS.TRAIN = ("page_train",)
    cfg.DATASETS.TEST = ("page_valid",)
    cfg.DATALOADER.NUM_WORKERS = 4
    # keep images that have no annotations in the training set
    cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False

    cfg.INPUT.MAX_SIZE_TRAIN = 1024
    cfg.INPUT.MIN_SIZE_TRAIN = (640, 1024)
    cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "range"
    cfg.INPUT.MAX_SIZE_TEST = 1024
    cfg.INPUT.MIN_SIZE_TEST = 1024

    cfg.INPUT.RANDOM_FLIP = "none"
    cfg.INPUT.CROP.ENABLED = False
    cfg.INPUT.CROP.TYPE = "relative_range"
    cfg.INPUT.CROP.SIZE = [0.8, 0.8]
    cfg.INPUT.COLOR_AUG_SSD = True

    # hyper-parameters
    cfg.SOLVER.IMS_PER_BATCH = 16
    # Never let an "epoch" collapse to 0 iterations when the dataset is smaller
    # than one batch; that would zero MAX_ITER and EVAL_PERIOD below.
    iters_per_epoch = max(1, num_train_data // cfg.SOLVER.IMS_PER_BATCH)
    max_iter = iters_per_epoch * 500
    cfg.SOLVER.MAX_ITER = max_iter

    cfg.SOLVER.BASE_LR = 0.02
    cfg.SOLVER.MOMENTUM = 0.9
    # decay the learning rate by GAMMA at 60% and 80% of the schedule
    cfg.SOLVER.STEPS = (int(max_iter * 0.6), int(max_iter * 0.8))
    cfg.SOLVER.GAMMA = 0.1
    cfg.SOLVER.WARMUP_ITERS = int(iters_per_epoch * 0.05)

    # evaluation/saving setting: evaluate once per epoch
    # (SOLVER.CHECKPOINT_PERIOD is intentionally left at its configured value)
    cfg.TEST.EVAL_PERIOD = iters_per_epoch
    cfg.OUTPUT_DIR = output_dir

    cfg.freeze()
    default_setup(cfg, args)
    return cfg

def plain_register_dataset():
    """Register the train and valid COCO-format splits with detectron2.

    Each split gets a loader in ``DatasetCatalog`` and its class names,
    evaluator type, annotation file and image root in ``MetadataCatalog``.
    """
    splits = (
        ("page_train", train_json, train_path),
        ("page_valid", valid_json, valid_path),
    )
    for split_name, json_file, image_root in splits:
        # Bind the loop values as lambda defaults so each registered loader
        # keeps its own split instead of the last one iterated.
        DatasetCatalog.register(
            split_name,
            lambda json_file=json_file, image_root=image_root: load_coco_json(
                json_file, image_root
            ),
        )
        MetadataCatalog.get(split_name).set(
            thing_classes=cls_names,
            evaluator_type='coco',
            json_file=json_file,
            image_root=image_root,
        )

def build_trainAug(cfg):
    """Assemble the list of training augmentations from the config.

    The pipeline always starts with a random rotation by a multiple of 90
    degrees and always contains a shortest-edge resize; random cropping and
    SSD-style colour augmentation are added only when enabled in ``cfg``.

    Args:
        cfg: config providing the ``INPUT`` options read below.

    Returns:
        The list of augmentation objects, in application order.
    """
    augs = [T.RandomRotation(angle=[0, 90, 180, 270], sample_style="choice")]

    if cfg.INPUT.CROP.ENABLED:
        augs.append(
            T.RandomCrop(
                crop_type=cfg.INPUT.CROP.TYPE,
                crop_size=cfg.INPUT.CROP.SIZE,
            )
        )

    augs.append(
        T.ResizeShortestEdge(
            short_edge_length=cfg.INPUT.MIN_SIZE_TRAIN,
            max_size=cfg.INPUT.MAX_SIZE_TRAIN,
            sample_style=cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
        )
    )

    if cfg.INPUT.COLOR_AUG_SSD:
        augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))

    return augs

class Trainer(DefaultTrainer):
    """
    Trainer built on "DefaultTrainer", which bundles pre-defined logic for the
    standard training workflow. That logic may not suit a new research
    project; in that case use the cleaner "SimpleTrainer" or write your own
    training loop.

    This subclass plugs in a COCO evaluator, a train loader that uses the
    custom augmentations, and a hook that keeps the best checkpoint.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.

        The evaluator is chosen from the "evaluator_type" metadata of the
        dataset; only "coco" is supported here.

        Args:
            cfg: config; ``OUTPUT_DIR`` is used when no folder is given.
            dataset_name: name of the registered dataset to evaluate.
            output_folder: where evaluation results are written; defaults to
                ``<OUTPUT_DIR>/inference``.

        Returns:
            The single evaluator, or a ``DatasetEvaluators`` wrapper when
            more than one applies.

        Raises:
            NotImplementedError: if no evaluator matches the dataset's type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type

        evaluators = []
        if evaluator_type == "coco":
            evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))

        if not evaluators:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        if len(evaluators) == 1:
            return evaluators[0]
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training data loader with the custom augmentation list."""
        mapper = DatasetMapper(cfg, is_train=True, augmentations=build_trainAug(cfg))
        return build_detection_train_loader(cfg, mapper=mapper)

    def build_hooks(self):
        """Return the default hooks plus, on the main process, a best-checkpoint hook.

        The extra hook tracks "segm/AP" (higher is better) at every
        evaluation period and saves through a ``DetectionCheckpointer``
        rooted at ``OUTPUT_DIR``.
        """
        cfg = self.cfg.clone()
        hooks = super().build_hooks()
        # Only the main process appends the checkpoint-saving hook.
        if comm.is_main_process():
            hooks.append(
                BestCheckpointer(
                    eval_period=cfg.TEST.EVAL_PERIOD,
                    checkpointer=DetectionCheckpointer(self.model, cfg.OUTPUT_DIR),
                    val_metric="segm/AP",
                    mode="max",
                )
            )
        return hooks

def main(args):
    """Run evaluation only, or a full training run, depending on ``args``.

    Args:
        args: parsed command-line arguments; ``eval_only`` selects the mode
            and ``resume`` controls whether an existing checkpoint is resumed.

    Returns:
        The evaluation results when ``args.eval_only`` is set, otherwise
        whatever ``trainer.train()`` returns.
    """
    cfg = setup(args)
    plain_register_dataset()

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        # Verify results once, on the main process only.
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()

# Script entry point: parse the standard detectron2 arguments and launch
# `main` across the requested GPUs/machines.
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

2. What exact command you run:

CUDA_VISIBLE_DEVICES=0,1 python pointRend_train.py --num-gpus=2 --num-machines=1

3. Full logs or other relevant observations:

[07/07 09:50:28] detectron2 INFO: Command line arguments: Namespace(config_file='', dist_url='tcp://127.0.0.1:51156', eval_only=False, machine_rank=0, num_gpus=2, num_machines=1, opts=[], resume=False) [07/07 09:50:28] detectron2 INFO: Running with full config: CUDNN_BENCHMARK: false DATALOADER: ASPECT_RATIO_GROUPING: true FILTER_EMPTY_ANNOTATIONS: false NUM_WORKERS: 4 REPEAT_THRESHOLD: 0.0 SAMPLER_TRAIN: TrainingSampler DATASETS: PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000 PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000 PROPOSAL_FILES_TEST: [] PROPOSAL_FILES_TRAIN: [] TEST:

[07/07 09:50:28] detectron2 INFO: Full config saved to ./ckpt5/config.yaml [07/07 09:50:29] d2.utils.env INFO: Using a generated random seed 30318684 [07/07 09:50:30] d2.engine.defaults INFO: Model: GeneralizedRCNN( (backbone): FPN( (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1)) (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1)) (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (top_block): LastLevelMaxPool() (bottom_up): ResNet( (stem): BasicStem( (conv1): Conv2d( 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) ) (res2): Sequential( (0): BottleneckBlock( (shortcut): Conv2d( 64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv1): Conv2d( 64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv2): Conv2d( 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv3): Conv2d( 64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) ) (1): BottleneckBlock( (conv1): Conv2d( 256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv2): Conv2d( 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv3): Conv2d( 64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False 
(norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) ) (2): BottleneckBlock( (conv1): Conv2d( 256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv2): Conv2d( 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05) ) (conv3): Conv2d( 64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) ) ) (res3): Sequential( (0): BottleneckBlock( (shortcut): Conv2d( 256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv1): Conv2d( 256, 128, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv2): Conv2d( 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv3): Conv2d( 128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) ) (1): BottleneckBlock( (conv1): Conv2d( 512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv2): Conv2d( 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv3): Conv2d( 128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) ) (2): BottleneckBlock( (conv1): Conv2d( 512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv2): Conv2d( 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv3): Conv2d( 128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) ) (3): BottleneckBlock( (conv1): Conv2d( 512, 128, kernel_size=(1, 1), 
stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv2): Conv2d( 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05) ) (conv3): Conv2d( 128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) ) ) (res4): Sequential( (0): BottleneckBlock( (shortcut): Conv2d( 512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) (conv1): Conv2d( 512, 256, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) (1): BottleneckBlock( (conv1): Conv2d( 1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) (2): BottleneckBlock( (conv1): Conv2d( 1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) (3): BottleneckBlock( (conv1): Conv2d( 1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) (4): BottleneckBlock( (conv1): Conv2d( 1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) (5): BottleneckBlock( (conv1): Conv2d( 1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05) ) (conv3): Conv2d( 256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05) ) ) ) (res5): Sequential( (0): BottleneckBlock( (shortcut): Conv2d( 1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05) ) (conv1): Conv2d( 1024, 512, kernel_size=(1, 1), stride=(2, 2), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv2): Conv2d( 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv3): Conv2d( 512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05) ) ) (1): BottleneckBlock( (conv1): Conv2d( 2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv2): Conv2d( 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, 
eps=1e-05) ) (conv3): Conv2d( 512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05) ) ) (2): BottleneckBlock( (conv1): Conv2d( 2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv2): Conv2d( 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05) ) (conv3): Conv2d( 512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05) ) ) ) ) ) (proposal_generator): RPN( (rpn_head): StandardRPNHead( (conv): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) (activation): ReLU() ) (objectness_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) (anchor_deltas): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1)) ) (anchor_generator): DefaultAnchorGenerator( (cell_anchors): BufferList() ) ) (roi_heads): StandardROIHeads( (box_pooler): ROIPooler( (level_poolers): ModuleList( (0): ROIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=0, aligned=True) (1): ROIAlign(output_size=(7, 7), spatial_scale=0.125, sampling_ratio=0, aligned=True) (2): ROIAlign(output_size=(7, 7), spatial_scale=0.0625, sampling_ratio=0, aligned=True) (3): ROIAlign(output_size=(7, 7), spatial_scale=0.03125, sampling_ratio=0, aligned=True) ) ) (box_head): FastRCNNConvFCHead( (flatten): Flatten(start_dim=1, end_dim=-1) (fc1): Linear(in_features=12544, out_features=1024, bias=True) (fc_relu1): ReLU() (fc2): Linear(in_features=1024, out_features=1024, bias=True) (fc_relu2): ReLU() ) (box_predictor): FastRCNNOutputLayers( (cls_score): Linear(in_features=1024, out_features=4, bias=True) (bbox_pred): Linear(in_features=1024, out_features=12, bias=True) ) (mask_head): PointRendMaskHead( (point_head): StandardPointHead( (fc1): Conv1d(259, 256, kernel_size=(1,), stride=(1,)) (fc2): Conv1d(259, 256, kernel_size=(1,), stride=(1,)) (fc3): 
Conv1d(259, 256, kernel_size=(1,), stride=(1,)) (predictor): Conv1d(259, 3, kernel_size=(1,), stride=(1,)) ) (coarse_head): ConvFCHead( (reduce_spatial_dim_conv): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2)) (fc1): Linear(in_features=12544, out_features=1024, bias=True) (fc2): Linear(in_features=1024, out_features=1024, bias=True) (prediction): Linear(in_features=1024, out_features=147, bias=True) ) ) ) ) [07/07 09:50:30] d2.data.dataset_mapper INFO: [DatasetMapper] Augmentations used in training: [RandomRotation(angle=[0, 90, 180, 270], sample_style='choice'), ResizeShortestEdge(short_edge_length=(640, 1024), max_size=1024), <detectron2.projects.point_rend.color_augmentation.ColorAugSSDTransform object at 0x7f1083128410>] [07/07 09:50:30] d2.data.datasets.coco INFO: Loaded 36 images in COCO format from /home/data/ctbc/nhi_segment3/page_train/annotations.json [07/07 09:50:30] d2.data.build INFO: Distribution of instances among all 3 categories:  category #instances category #instances category #instances
background 0 HPDF01 0 HAPP01 0
total 0 

[07/07 09:50:30] d2.data.build INFO: Using training sampler TrainingSampler [07/07 09:50:30] d2.data.common INFO: Serializing 36 elements to byte tensors and concatenating them all ... [07/07 09:50:30] d2.data.common INFO: Serialized dataset takes 0.01 MiB [07/07 09:50:30] fvcore.common.checkpoint INFO: [Checkpointer] Loading from /autohome/user/jason/project/aione_dev/seg_preTrain/model_final_rcnn_R_50_FPN_3x.pkl ... [07/07 09:50:31] fvcore.common.checkpoint INFO: Reading a file from 'Detectron2 Model Zoo' [07/07 09:50:31] d2.projects.point_rend.mask_head WARNING: Weight format of PointRend models have changed! Applying automatic conversion now ... [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (4,) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (12, 1024) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (12,) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.coarse_head.prediction.weight' to the model due to incompatible shapes: (3920, 1024) in the checkpoint but (147, 1024) in the model! You might want to double check if this is expected. 
[07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.coarse_head.prediction.bias' to the model due to incompatible shapes: (3920,) in the checkpoint but (147,) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.point_head.fc1.weight' to the model due to incompatible shapes: (256, 336, 1) in the checkpoint but (256, 259, 1) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.point_head.fc2.weight' to the model due to incompatible shapes: (256, 336, 1) in the checkpoint but (256, 259, 1) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.point_head.fc3.weight' to the model due to incompatible shapes: (256, 336, 1) in the checkpoint but (256, 259, 1) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.point_head.predictor.weight' to the model due to incompatible shapes: (80, 336, 1) in the checkpoint but (3, 259, 1) in the model! You might want to double check if this is expected. [07/07 09:50:31] fvcore.common.checkpoint WARNING: Skip loading parameter 'roi_heads.mask_head.point_head.predictor.bias' to the model due to incompatible shapes: (80,) in the checkpoint but (3,) in the model! You might want to double check if this is expected. 
[07/07 09:50:31] fvcore.common.checkpoint WARNING: Some model parameters or buffers are not found in the checkpoint: roi_heads.box_predictor.bbox_pred.{bias, weight} roi_heads.box_predictor.cls_score.{bias, weight} roi_heads.mask_head.coarse_head.prediction.{bias, weight} roi_heads.mask_head.point_head.fc1.weight roi_heads.mask_head.point_head.fc2.weight roi_heads.mask_head.point_head.fc3.weight roi_heads.mask_head.point_head.predictor.{bias, weight} [07/07 09:50:31] d2.engine.train_loop INFO: Starting training from iteration 0 [07/07 09:50:35] d2.engine.train_loop ERROR: Exception during training: Traceback (most recent call last): File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 149, in train self.run_step() File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/engine/defaults.py", line 494, in run_step self._trainer.run_step() File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 273, in run_step loss_dict = self.model(data) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl return forward_call(*input, kwargs) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 886, in forward output = self.module(*inputs[0], *kwargs[0]) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl return forward_call(input, kwargs) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/modeling/metaarch/rcnn.py", line 163, in forward , detector_losses = self.roi_heads(images, features, proposals, gt_instances) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl return forward_call(*input, *kwargs) File 
"/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/modeling/roi_heads/roi_heads.py", line 743, in forward losses.update(self._forward_mask(features, proposals)) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/modeling/roi_heads/roi_heads.py", line 846, in _forward_mask return self.mask_head(features, instances) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl return forward_call(input, **kwargs) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/projects/point_rend/mask_head.py", line 233, in forward point_coords, point_labels = self._sample_train_points(coarse_mask, instances) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/projects/point_rend/mask_head.py", line 285, in _sample_train_points point_labels = sample_point_labels(instances, point_coords_wrt_image) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/projects/point_rend/point_features.py", line 258, in sample_point_labels point_labels = cat(gt_mask_logits) File "/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2/layers/wrappers.py", line 45, in cat return torch.cat(tensors, dim) NotImplementedError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat. This usually means that this function requires a non-empty list of Tensors, or that you (the operator writer) forgot to register a fallback function. 
Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:18433 [kernel] CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:26496 [kernel] QuantizedCPU: registered at aten/src/ATen/RegisterQuantizedCPU.cpp:1068 [kernel] BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback] Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:47 [backend fallback] Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback] Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback] Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback] ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback] AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradMLC: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] AutogradPrivateUse3: registered at 
../torch/csrc/autograd/generated/VariableType_3.cpp:10141 [autograd kernel] Tracer: registered at ../torch/csrc/autograd/generated/TraceType_3.cpp:11560 [kernel] UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:466 [backend fallback] Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:305 [backend fallback] Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback] VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

[07/07 09:50:35] d2.engine.hooks INFO: Total training time: 0:00:03 (0:00:00 on hooks) [07/07 09:50:35] d2.utils.events INFO: iter: 0 lr: N/A max_mem: 6422M


## Expected behavior:
Training should run without raising a RuntimeError.

## Environment:

Paste the output of the following command:

sys.platform linux Python 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0] numpy 1.21.6 detectron2 0.6 @/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/detectron2 Compiler GCC 7.3 CUDA compiler CUDA 10.2 detectron2 arch flags 3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5 DETECTRON2_ENV_MODULE PyTorch 1.10.1+cu102 @/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torch PyTorch debug build False GPU available Yes GPU 0,1 Tesla P100-PCIE-16GB (arch=6.0) Driver version 470.103.01 CUDA_HOME /usr/local/cuda Pillow 9.1.1 torchvision 0.11.2+cu102 @/autohome/user/jason/miniconda3/envs/NHI/lib/python3.7/site-packages/torchvision torchvision arch flags 3.5, 5.0, 6.0, 7.0, 7.5 fvcore 0.1.5.post20220512 iopath 0.1.9 cv2 4.4.0


PyTorch built with:

jasonchenPJ commented 2 years ago

Hi @ppwwyyxx sorry for bothering you, do you have any suggestions for me about this issue?

shining-love commented 2 years ago

Hi @ppwwyyxx sorry for bothering you, do you have any suggestions for me about this issue?

Hello, have you solved this problem?

jasonchenPJ commented 2 years ago

Not yet... I'm still confused.

carlos-havier commented 1 year ago

Hi guys. I came across the same problem while training PointRend with RandomCrop as an augmentation. I think the issue arises, as in your case, when an image (or a crop, in my case) has no positive labels in its masks.

I patched point_features.py in anaconda3\envs\hands\Lib\site-packages\detectron2-0.4.1-py3.9-win-amd64.egg\detectron2\projects\point_rend (or wherever your environment is) by changing:

point_labels = cat(gt_mask_logits)

for:

    # Unpack the three axes of point_coords; the last one must have size 2.
    # NOTE(review): presumably sR = number of ROIs and sP = points per ROI —
    # confirm against the enclosing sample_point_labels.
    sR, sP, s2 = point_coords.shape
    assert s2 == 2, point_coords.shape
    if gt_mask_logits:
        point_labels = cat(gt_mask_logits)
    else:
        # cat() raises on an empty list (the error reported in this issue), so
        # fall back to an all-zero (sR, sP) tensor that matches point_coords'
        # dtype, layout and device.
        point_labels = torch.zeros((sR, sP), dtype=point_coords.dtype, layout=point_coords.layout, device=point_coords.device)

Seems to be working fine with my data.

If it also works for you, I suggest someone patches it in Detectron2 and does a pull request.

Cheers

jasonchenPJ commented 1 year ago

Hi @carlos-havier Thanks for your reply. Based on your suggestion, I retried the "full no-object dataset" case, and it works now.

Thank you so much~

emmanuel-nwogu commented 1 year ago

I am also having a similar error when trying to train using PointRend:

ERROR [10/19 22:25:15 d2.engine.train_loop]: Exception during training:
Traceback (most recent call last):
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/engine/defaults.py", line 494, in run_step
    self._trainer.run_step()
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/engine/train_loop.py", line 274, in run_step
    loss_dict = self.model(data)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/modeling/meta_arch/rcnn.py", line 167, in forward
    _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/modeling/roi_heads/roi_heads.py", line 743, in forward
    losses.update(self._forward_mask(features, proposals))
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/modeling/roi_heads/roi_heads.py", line 846, in _forward_mask
    return self.mask_head(features, instances)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/projects/PointRend/point_rend/mask_head.py", line 233, in forward
    point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/projects/PointRend/point_rend/mask_head.py", line 285, in _sample_train_points
    point_labels = sample_point_labels(instances, point_coords_wrt_image)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/projects/PointRend/point_rend/point_features.py", line 258, in sample_point_labels
    point_labels = cat(gt_mask_logits)
  File "/content/gdrive/MyDrive/4mlab/detectron2_repo/detectron2/layers/wrappers.py", line 46, in cat
    return torch.cat(tensors, dim)
RuntimeError: torch.cat(): expected a non-empty list of Tensors

Training code:

from detectron2.checkpoint import DetectionCheckpointer

# Fine-tune a PointRend instance-segmentation model on the "femur_train" dataset
# and save the resulting weights under cfg.OUTPUT_DIR.
# NOTE(review): get_cfg, point_rend, os and DefaultTrainer are not imported in this
# snippet; presumably they are imported earlier in the notebook -- confirm.

# --- Base config: detectron2 defaults, then the PointRend keys, then the yaml ---
cfg = get_cfg()
# Add PointRend-specific config (called before merge_from_file on the PointRend yaml)
point_rend.add_pointrend_config(cfg)
cfg.merge_from_file("/content/gdrive/MyDrive/4mlab/detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Start from the PointRend checkpoint referenced by this detectron2:// URI.
cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_edd263.pkl"

# --- Datasets and data loading ---
cfg.DATASETS.TRAIN = ("femur_train",)
cfg.DATASETS.TEST = ()  # no test dataset configured; "femur_val" is left disabled
cfg.DATALOADER.NUM_WORKERS = 1

# --- Solver ---
cfg.SOLVER.IMS_PER_BATCH = 1  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 100    # 100 iterations here (the original note suggested 300 for a toy dataset); you will need to train longer for a practical dataset
cfg.SOLVER.STEPS = []        # do not decay learning rate

# --- Head sizes / class counts ---
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (femur). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.
# Keep the point head and semantic-segmentation head class counts equal to the ROI heads' value.
cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
cfg.TEST.DETECTIONS_PER_IMAGE = 5
cfg.OUTPUT_DIR = r"/content/gdrive/MyDrive/4mlab/checkpoint"

# --- Train, then save the trained model ---
# Create the output directory up front; exist_ok=True makes re-runs safe.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
# resume=False: do not resume from a previous run in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()
# Save the trained model into OUTPUT_DIR under the name "femur_detect_model".
checkpointer = DetectionCheckpointer(trainer.model, save_dir=cfg.OUTPUT_DIR)
checkpointer.save("femur_detect_model")

Funnily enough, when I ran this same code months ago (Sept. 29, 2022), it all worked perfectly. I suspect a recent source code change may be responsible.

emmanuel-nwogu commented 1 year ago

I just figured out a surprising fix inspired by this issue comment. I changed cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 to cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 despite actually having one class. It works now :)

emmanuel-nwogu commented 1 year ago

Sorry if I derailed this issue. I should perhaps have started a new issue. Also, the fix I mentioned above does not work in all cases. When I use images with no masks, like the author, the issue persists. I haven't tried the fix by @carlos-havier yet. I'm hoping it's a good fix. Thanks, Carlos :)

emmanuel-nwogu commented 1 year ago

Just tried it. It does work great! Thanks again :) @carlos-havier

aymanaboghonim commented 1 year ago

thanks @carlos-havier , you saved my day.