facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

Training panoptic segmentation model on COCO data fails #4899

Open uziela opened 1 year ago

uziela commented 1 year ago

Instructions To Reproduce the Issue:

I want to train the panoptic segmentation model on a custom data set. However, before I prepare my own custom data set, I wanted to make sure that the training works on the COCO dataset. I saw @ppwwyyxx comment that training panoptic segmentation on a custom data set is not supported, but I also saw other people discussing doing it successfully, so I thought I would try it, too.

Since I don't need a large data set for this trial, I decided to train on the COCO val2017 data set (rather than train2017). I downloaded these zip files:

http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip

I am using detectron2.data.datasets.register_coco_panoptic_separated as suggested in #1691.

However, when I try to run the code below, I get an error "IndexError: Target 202 is out of bounds." from torch._C._nn.cross_entropy_loss.

As far as I understand, the COCO data set contains 80 "thing" classes and 53 "stuff" classes. The category IDs in panoptic_val2017.json run from 1 to 200. So the number of targets should be either 80 + 53 + 1 = 134, or 200 at most. So where does this "Target 202" come from?
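For context, the error itself just means that some ground-truth pixel label is >= NUM_CLASSES. A minimal sketch reproducing the same IndexError with plain PyTorch, independent of detectron2:

    import torch
    import torch.nn.functional as F

    # With NUM_CLASSES = 200 the valid target values are 0..199.
    logits = torch.randn(1, 200)     # (batch, num_classes) scores, as from the sem_seg head
    target = torch.tensor([202])     # a pixel label outside 0..199
    F.cross_entropy(logits, target)  # IndexError: Target 202 is out of bounds.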

Interestingly, if I set cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 203, the model trains successfully (not sure if the resulting model is meaningful though).

Another interesting point is that if I uncomment the line

    sem_seg_root = "/scratch2/coco/from_website/panoptic2017/panoptic_annotations_trainval2017/annotations/panoptic_val2017/panoptic_val2017/"

then even setting cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 250 does not help, because I get the error "IndexError: Target 251 is out of bounds."

Since changing sem_seg_root changes the number of targets, I assume the number of targets comes from the semantic segmentation bitmaps. Is it equal to the number of distinct colors in the semantic bitmaps? For panoptic masks, the bitmap colors are mapped to the "segment_info.id" field in panoptic_val2017.json using the formula ids = R + G*256 + B*256^2 (as instructed in https://cocodataset.org/#format-data); however, I'm not sure how semantic segmentation mask colors are mapped to category info in the JSON files.
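For reference, a minimal sketch of that decoding, following the formula from the COCO format docs (the file name is just one example image from val2017):

    import numpy as np
    from PIL import Image

    # Decode a COCO panoptic PNG into per-pixel segment ids:
    # ids = R + G*256 + B*256^2 (https://cocodataset.org/#format-data)
    rgb = np.asarray(Image.open("panoptic_val2017/000000000139.png"), dtype=np.uint32)
    ids = rgb[..., 0] + rgb[..., 1] * 256 + rgb[..., 2] * 256 ** 2
    print(np.unique(ids))  # these match the segment_info "id" values in panoptic_val2017.json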

A related question: does the custom panoptic segmentation training pipeline use panoptic bitmaps (panoptic_root) at all, or does it only use semantic bitmaps (sem_seg_root)?

  1. Full runnable code or full changes you made:

    from detectron2.data.datasets import register_coco_panoptic_separated
    from detectron2 import model_zoo
    from detectron2.engine import DefaultTrainer
    from detectron2.config import get_cfg

    image_root = "/scratch2/coco/from_website/val2017/"
    panoptic_root = "/scratch2/coco/from_website/panoptic2017/panoptic_annotations_trainval2017/annotations/panoptic_val2017/panoptic_val2017/"  # from http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip
    panoptic_json = "/scratch2/coco/from_website/panoptic2017/panoptic_annotations_trainval2017/annotations/panoptic_val2017.json"  # from http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip
    sem_seg_root = "/scratch2/coco/from_website/stuff2017/stuff_annotations_trainval2017/annotations/stuff_val2017_pixelmaps/"  # from http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
    # sem_seg_root = "/scratch2/coco/from_website/panoptic2017/panoptic_annotations_trainval2017/annotations/panoptic_val2017/panoptic_val2017/"
    instances_json = "/scratch2/coco/from_website/val2017/annotations/instances_val2017.json"  # from http://images.cocodataset.org/zips/val2017.zip

    register_coco_panoptic_separated("my_dataset", {}, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json)

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml"))
    cfg.MODEL.DEVICE = "cpu"
    cfg.DATASETS.TRAIN = ("my_dataset_separated",)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 53
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 200
    # cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 203
    # cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 250
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml")
    cfg.SOLVER.IMS_PER_BATCH = 1  # This is the real "batch size" commonly known to deep learning people.
    cfg.INPUT.MASK_FORMAT = "bitmask"

    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()

2. What exact command you run:

I saved the above commands in train_model_panoptic_issue.py and ran:
python train_model_panoptic_issue.py

3. Full logs or other relevant observations:

    [04/06 10:10:14 d2.engine.hooks]: Total training time: 0:00:03 (0:00:00 on hooks)
    [04/06 10:10:14 d2.utils.events]: iter: 0 lr: N/A
    Traceback (most recent call last):
      File "train_model_panoptic_issue.py", line 38, in <module>
        trainer.train()
      File "/usr/local/lib/python3.8/dist-packages/detectron2/engine/defaults.py", line 484, in train
        super().train(self.start_iter, self.max_iter)
      File "/usr/local/lib/python3.8/dist-packages/detectron2/engine/train_loop.py", line 149, in train
        self.run_step()
      File "/usr/local/lib/python3.8/dist-packages/detectron2/engine/defaults.py", line 494, in run_step
        self._trainer.run_step()
      File "/usr/local/lib/python3.8/dist-packages/detectron2/engine/train_loop.py", line 274, in run_step
        loss_dict = self.model(data)
      File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/detectron2/modeling/meta_arch/panoptic_fpn.py", line 127, in forward
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg)
      File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/detectron2/modeling/meta_arch/semantic_seg.py", line 239, in forward
        return None, self.losses(x, targets)
      File "/usr/local/lib/python3.8/dist-packages/detectron2/modeling/meta_arch/semantic_seg.py", line 263, in losses
        loss = F.cross_entropy(
      File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2996, in cross_entropy
        return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
    IndexError: Target 202 is out of bounds.


## Expected behavior:

The expected behavior is that model training works when the semantic segmentation head's number of classes is set to num_thing_classes + num_stuff_classes + 1:

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 80 + 53 + 1

## Environment:

Paste the output of the following command:

    2023-04-06 10:16:41 URL:https://raw.githubusercontent.com/facebookresearch/detectron2/main/detectron2/utils/collect_env.py [8525/8525] -> "collect_env.py" [1]
    No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


    sys.platform            linux
    Python                  3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0]
    numpy                   1.22.3
    detectron2              0.6 @/usr/local/lib/python3.8/dist-packages/detectron2
    Compiler                GCC 9.4
    CUDA compiler           not available
    DETECTRON2_ENV_MODULE
    PyTorch                 1.11.0+cu102 @/usr/local/lib/python3.8/dist-packages/torch
    PyTorch debug build     False
    GPU available           No: torch.cuda.is_available() == False
    Pillow                  9.3.0
    torchvision             0.12.0+cu102 @/usr/local/lib/python3.8/dist-packages/torchvision
    fvcore                  0.1.5.post20221213
    iopath                  0.1.9
    cv2                     4.5.4


PyTorch built with:

uziela commented 1 year ago

OK, it seems I have solved my problem. It appears that you cannot use the semantic annotations from http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip to train a detectron2 panoptic segmentation model. Those semantic annotations are not in the right format (the pixel values do not correspond to the category ids as they should).
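A quick way to see this is a hypothetical check like the one below (the file name is just one example image from val2017):

    import numpy as np
    from PIL import Image

    # Inspect the raw pixel values of one stuff pixelmap. If these were usable
    # directly as cross-entropy targets, they would all be < SEM_SEG_HEAD.NUM_CLASSES;
    # the traceback above shows that values like 202 turn up instead.
    png = np.asarray(Image.open(
        "/scratch2/coco/from_website/stuff2017/stuff_annotations_trainval2017/"
        "annotations/stuff_val2017_pixelmaps/000000000139.png"))
    print(png.dtype, png.shape)
    print(np.unique(png))  # raw pixel values, not contiguous training ids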

I actually used the separate_coco_semantic_from_panoptic function from https://github.com/IDEA-Research/MaskDINO/blob/main/datasets/prepare_coco_semantic_annos_from_panoptic_annos.py to extract semantic annotations from the panoptic annotations. What it does is convert the panoptic bitmaps to semantic bitmaps, so that pixel values in the semantic bitmaps correspond to category ids (in the range 1-133). The training now works.
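Roughly, the conversion looks like this (a simplified sketch of the idea, not the exact MaskDINO code; it writes raw category_ids and omits the remapping to a contiguous id range):

    import numpy as np
    from PIL import Image

    def panoptic_to_semantic(panoptic_png, segments_info, out_png, ignore_label=255):
        """Turn one panoptic bitmap into a semantic bitmap whose pixel
        values are category ids (what the sem_seg head trains on)."""
        rgb = np.asarray(Image.open(panoptic_png), dtype=np.uint32)
        ids = rgb[..., 0] + rgb[..., 1] * 256 + rgb[..., 2] * 256 ** 2
        sem = np.full(ids.shape, ignore_label, dtype=np.uint8)  # unlabeled pixels stay ignored
        for seg in segments_info:  # the "segments_info" list for this image in panoptic_val2017.json
            sem[ids == seg["id"]] = seg["category_id"]
        Image.fromarray(sem).save(out_png)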

It would be great to have more documentation on how to prepare data for panoptic segmentation training, though. Is it still unsupported? Also, I'm still curious whether panoptic masks are used in any way during training, and if not, why we need them when registering the dataset with the register_coco_panoptic_separated function.

Phylanxy commented 7 months ago

The panoptic annotations are used for evaluation. #1691 (comment)