IDEA-Research / detrex

detrex is a research platform for DETR-based object detection, segmentation, pose estimation and other visual recognition tasks.
https://detrex.readthedocs.io/en/latest/
Apache License 2.0

how to change the number of classes #141

Closed. buptxiaofeng closed this issue 1 year ago.

buptxiaofeng commented 1 year ago

Hi, thanks for this helpful project. I am training on a custom dataset, and everything works except that I cannot change the number of classes. I tried adding the line "model.num_classes=3" to projects/dab_detr/configs/dab_detr_r50_50ep.py, but that does not seem to be enough. With this line in the config, training fails with the following error:

  File "/home/binxiao/code/py3_env/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "/home/binxiao/code/detrex/tools/train_net.py", line 96, in run_step
    loss_dict = self.model(data)
  File "/home/binxiao/code/py3_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/binxiao/code/py3_env/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/binxiao/code/py3_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/binxiao/code/detrex/projects/dab_detr/modeling/dab_detr.py", line 198, in forward
    loss_dict = self.criterion(output, targets)
  File "/home/binxiao/code/py3_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/binxiao/code/detrex/detrex/modeling/criterion/criterion.py", line 230, in forward
    losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
  File "/home/binxiao/code/detrex/detrex/modeling/criterion/criterion.py", line 198, in get_loss
    return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  File "/home/binxiao/code/detrex/detrex/modeling/criterion/criterion.py", line 162, in loss_boxes
    src_boxes = outputs["pred_boxes"][idx]
RuntimeError: CUDA error: device-side assert triggered
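A CUDA device-side assert raised during loss computation usually means some kernel indexed out of bounds. Because CUDA kernels run asynchronously, the Python traceback may point at a later call (here loss_boxes) rather than the one that actually failed; rerunning with the environment variable CUDA_LAUNCH_BLOCKING=1 makes the error surface at the failing operation. With a custom dataset, one cause worth ruling out is a label outside [0, num_classes). The sketch below assumes detectron2-style dataset dicts whose annotations carry contiguous, 0-based "category_id" values; find_out_of_range_labels is a hypothetical helper, not part of detrex or detectron2.

```python
# Sketch: check that every annotation's category_id is a valid class index
# for a model configured with `num_classes` classes.
# Assumption: detectron2-style dataset dicts (a list of dicts, each with an
# "annotations" list whose items carry a contiguous 0-based "category_id"),
# e.g. what DatasetCatalog.get("<your dataset name>") returns.
from typing import Iterable, List, Tuple


def find_out_of_range_labels(
    dataset_dicts: Iterable[dict], num_classes: int
) -> List[Tuple[int, int]]:
    """Return (image index, category_id) pairs whose id is not in [0, num_classes)."""
    bad = []
    for img_idx, record in enumerate(dataset_dicts):
        # Images without annotations are allowed and simply skipped.
        for ann in record.get("annotations", []):
            cat = ann["category_id"]
            if not 0 <= cat < num_classes:
                bad.append((img_idx, cat))
    return bad


if __name__ == "__main__":
    # Toy records for illustration only.
    toy = [
        {"file_name": "a.jpg", "annotations": [{"category_id": 0}, {"category_id": 2}]},
        {"file_name": "b.jpg", "annotations": [{"category_id": 3}]},
    ]
    print(find_out_of_range_labels(toy, num_classes=3))  # [(1, 3)]
```

An empty result means the labels themselves are fine, which points back at the model/criterion configuration.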

buptxiaofeng commented 1 year ago

It turns out that I needed to modify projects/dab_detr/configs/models/dab_detr_r50.py instead, so I am closing this issue.

rentainhe commented 1 year ago


Just set model.num_classes = n in the config. I'm closing this issue; feel free to reopen it if needed.

zldrobit commented 1 year ago

@rentainhe What about model.criterion.num_classes? Is it necessary to change it as well?

rentainhe commented 1 year ago


Yes, it should be changed together with model.num_classes.
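A small sketch of keeping the two settings in sync and failing fast when they disagree. It assumes, as this thread indicates, that the DAB-DETR model config exposes both model.num_classes and model.criterion.num_classes. SimpleNamespace is only a stand-in for the LazyConfig model object, the starting value of 80 (COCO) is an assumption, and set_num_classes / check_num_classes are hypothetical helpers. A plausible explanation for the original error, not confirmed in this thread: DETR-style criteria typically use their own num_classes as the "no object" index when building targets sized from the model's logits, so a criterion still set to 80 on a 3-class head can index out of range.

```python
# Sketch: keep the head's class count and the criterion's class count in sync.
# SimpleNamespace stands in for the LazyConfig `model` object (assumption);
# in a real config file you would just write the two assignments directly.
from types import SimpleNamespace


def set_num_classes(model_cfg, num_classes: int) -> None:
    """Set both fields that must agree (hypothetical helper)."""
    model_cfg.num_classes = num_classes
    model_cfg.criterion.num_classes = num_classes


def check_num_classes(model_cfg) -> None:
    """Raise before training starts if the two values disagree."""
    head, crit = model_cfg.num_classes, model_cfg.criterion.num_classes
    if head != crit:
        raise ValueError(
            f"model.num_classes={head} but model.criterion.num_classes={crit}"
        )


if __name__ == "__main__":
    # Stand-in config; 80 in both places is an assumed COCO default.
    model = SimpleNamespace(num_classes=80, criterion=SimpleNamespace(num_classes=80))

    model.num_classes = 3  # only the head changed, as in the original report
    try:
        check_num_classes(model)
    except ValueError as err:
        print("mismatch:", err)

    set_num_classes(model, 3)
    check_num_classes(model)  # no error now
    print(model.num_classes, model.criterion.num_classes)  # 3 3
```

In an actual detrex config, the equivalent is simply the two assignments model.num_classes = 3 and model.criterion.num_classes = 3 placed after model is imported.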