JosephKJ / OWOD

(CVPR 2021 Oral) Open World Object Detection
https://josephkj.in
Apache License 2.0

IndexError: list index out of range when CUR_INTRODUCED_CLS > 84 #125

Closed by sf-pear 1 year ago

sf-pear commented 1 year ago

Hello,

I'm trying to train OWOD on a custom dataset, but I'm running into an error I can't figure out how to fix. I have 290 classes in the dataset, so this is how I set up my config file:

_BASE_: "Base-RCNN-C4-OWOD.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
  TRAIN: ('kaggle_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
  TEST: ('kaggle_eval_without_unknowns', 'kaggle_eval_with_unknown')   # voc_coco_2007_test, t1_voc_coco_2007_test, t1_voc_coco_2007_val
SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000
  WARMUP_ITERS: 100
OUTPUT_DIR: "./output/kagglev2-20230710"
OWOD:
  PREV_INTRODUCED_CLS: 0
  CUR_INTRODUCED_CLS: 290

I also experimented a bit and found that CUR_INTRODUCED_CLS accepts values up to 84; at 85 I start getting the IndexError. Has anyone experienced this? I'm trying to understand where the error comes from so I can fix it.

Note: If I set CUR_INTRODUCED_CLS to less than 290, I get an incorrect distribution of labels for my train set.

Here is the full error:

Traceback (most recent call last):
  File "tools/train_net.py", line 169, in <module>
    args=(args,),
  File "/home/sabrina/code/OWOD/detectron2/engine/launch.py", line 59, in launch
    daemon=False,
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/home/sabrina/code/OWOD/detectron2/engine/launch.py", line 94, in _distributed_worker
    main_func(*args)
  File "/home/sabrina/code/OWOD/tools/train_net.py", line 157, in main
    return trainer.train()
  File "/home/sabrina/code/OWOD/detectron2/engine/defaults.py", line 408, in train
    super().train(self.start_iter, self.max_iter)
  File "/home/sabrina/code/OWOD/detectron2/engine/train_loop.py", line 147, in train
    self.run_step()
  File "/home/sabrina/code/OWOD/detectron2/engine/train_loop.py", line 317, in run_step
    loss_dict = self.model(data)
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 511, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sabrina/code/OWOD/detectron2/modeling/meta_arch/rcnn.py", line 166, in forward
    _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
  File "/home/sabrina/mambaforge/envs/ow/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sabrina/code/OWOD/detectron2/modeling/roi_heads/roi_heads.py", line 478, in forward
    self.box_predictor.update_feature_store(input_features, proposals)
  File "/home/sabrina/code/OWOD/detectron2/modeling/roi_heads/fast_rcnn.py", line 560, in update_feature_store
    self.feature_store.add(features, gt_classes)
  File "/home/sabrina/code/OWOD/detectron2/utils/store.py", line 14, in add
    self.store[class_id].append(items[idx])
IndexError: list index out of range
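
The last frame of the traceback indexes `self.store` by `class_id`. A "list index out of range" at that point suggests the store holds fewer per-class buckets than the largest class id it was handed. The sketch below reproduces that failure mode in isolation. `PerClassStore`, `num_buckets`, and `checked_add` are hypothetical names made up for illustration, not the code in `detectron2/utils/store.py`, and the bucket count of 3 is arbitrary. Where the real store gets its size from is not visible in the traceback; a model-level class count such as `MODEL.ROI_HEADS.NUM_CLASSES` in the base config is one place worth checking, but that is an assumption.

```python
from collections import deque


class PerClassStore:
    """Hypothetical stand-in for a per-class feature store (illustration only)."""

    def __init__(self, num_buckets, items_per_class):
        # One bounded queue per class id. The list length is fixed at
        # construction time, so it limits which class ids can be stored.
        self.store = [deque(maxlen=items_per_class) for _ in range(num_buckets)]

    def add(self, items, class_ids):
        # Same indexing pattern as the failing line in the traceback.
        for idx, class_id in enumerate(class_ids):
            self.store[class_id].append(items[idx])

    def checked_add(self, items, class_ids):
        # Variant that fails with a descriptive message instead of a bare
        # IndexError (and also rejects negative ids, which would otherwise
        # silently index from the end of the list).
        for idx, class_id in enumerate(class_ids):
            if not 0 <= class_id < len(self.store):
                raise ValueError(
                    f"class id {class_id} does not fit a store with "
                    f"{len(self.store)} buckets"
                )
            self.store[class_id].append(items[idx])


store = PerClassStore(num_buckets=3, items_per_class=2)
store.add(["feat_a"], [2])       # fits: valid ids are 0, 1, 2

try:
    store.add(["feat_b"], [3])   # one past the last bucket
except IndexError as err:
    print("IndexError:", err)
```

If this matches what happens in the real store, printing `len(self.store)` and `max(class_ids)` just before the failing line should show which side of the mismatch needs to change.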