Jingkang50 / OpenPSG

Benchmarking Panoptic Scene Graph Generation (PSG), ECCV'22
https://psgdataset.org
MIT License

Cannot use `rel_loss_cls = dict(use_sigmoid=True)` #55

Closed. ShunchiZhang closed this issue 1 year ago

ShunchiZhang commented 2 years ago

When I add the line `rel_loss_cls = dict(use_sigmoid=True)` to the PSGTr config file, the following error occurs:

RuntimeError: The size of tensor a (57) must match the size of tensor b (56) at non-singleton dimension 1

Full Log

```
Traceback (most recent call last):
  File "tools/train.py", line 225, in <module>
    main()
  File "tools/train.py", line 220, in main
    meta=meta,
  File "/opt/conda/lib/python3.7/site-packages/mmdet/apis/train.py", line 209, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 248, in train_step
    losses = self(**data)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 172, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/home/shunchi/OpenPSG/openpsg/models/frameworks/psgtr.py", line 139, in forward_train
    gt_bboxes_ignore)
  File "/home/shunchi/OpenPSG/openpsg/models/relation_heads/psgtr_head.py", line 871, in forward_train
    losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 186, in new_func
    return old_func(*args, **kwargs)
  File "/home/shunchi/OpenPSG/openpsg/models/relation_heads/psgtr_head.py", line 399, in loss
    img_metas_list, all_gt_bboxes_ignore_list)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/core/utils/misc.py", line 30, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/shunchi/OpenPSG/openpsg/models/relation_heads/psgtr_head.py", line 548, in loss_single
    avg_factor=cls_avg_factor)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/models/losses/cross_entropy_loss.py", line 250, in forward
    **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/models/losses/cross_entropy_loss.py", line 108, in binary_cross_entropy
    pred, label.float(), pos_weight=class_weight, reduction='none')
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 2982, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (57) must match the size of tensor b (56) at non-singleton dimension 1
```

I wonder whether it's because `rel_cls_out_channels` becomes `num_relations` instead of `num_relations + 1`. What is the difference between `use_sigmoid=True` and `use_sigmoid=False`, and why does it influence `rel_cls_out_channels`?

Thank you so much!

Jingkang50 commented 1 year ago

That seems weird to me. use_sigmoid should not affect the out channels. Did you change other stuff?

ShunchiZhang commented 1 year ago

Sorry for the late reply.

  1. I am quite sure I changed nothing from the original code;
  2. This snippet suggests that `use_sigmoid` does affect the value of `*_out_channels`: https://github.com/Jingkang50/OpenPSG/blob/23f9c89d43eeaa6e4ee57895d11ac9fb3546fd47/openpsg/models/relation_heads/psgtr_head.py#L187-L200
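
For illustration, here is a minimal sketch of the convention that the linked lines appear to follow, as described in this thread: with a sigmoid loss each class gets its own independent logit, so no background channel is allocated, while a softmax loss needs one extra channel for background. The helper name `cls_out_channels` and the value `num_relations = 56` are assumptions picked to match the 56/57 in the error message; this is not the repository's code.

```python
def cls_out_channels(num_classes: int, use_sigmoid: bool) -> int:
    """Width of a classification layer under the assumed convention.

    Sigmoid: one independent logit per class, background is "all logits low".
    Softmax: one extra channel is reserved for the background class.
    """
    return num_classes if use_sigmoid else num_classes + 1


# Hypothetical value, chosen to match the sizes in the error message.
num_relations = 56

softmax_width = cls_out_channels(num_relations, use_sigmoid=False)
sigmoid_width = cls_out_channels(num_relations, use_sigmoid=True)
print(softmax_width, sigmoid_width)
```

If any other tensor involved in the loss (labels, class weights) is still sized for `num_relations + 1`, switching to sigmoid would leave it one element wider than the prediction, which is consistent with the 57 vs. 56 mismatch reported above.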

ShunchiZhang commented 1 year ago

I guess a potential bug is that class labels in `gt_rel` start from 1 rather than 0, so reducing `cls_out_channels` could lead to IndexErrors (this doesn't happen for subject and object, only for relation). However, I am not sure which code is associated with this special setting; I would much appreciate it if you could point it out.
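
The off-by-one suspected here can be checked with a small sketch. It assumes, as guessed above, that relation labels run from 1 to `num_relations` with 0 reserved for background; the helper `labels_outside_channels` and the sample labels are hypothetical.

```python
def labels_outside_channels(labels, num_channels):
    """Return the labels that cannot index a row of `num_channels` logits."""
    return [lab for lab in labels if not 0 <= lab < num_channels]


num_relations = 56
# Hypothetical 1-based relation labels (0 would mean background).
gt_rel_labels = [1, 17, 56]

# Softmax-style head: num_relations + 1 channels, every label fits.
fits_softmax = labels_outside_channels(gt_rel_labels, num_relations + 1)

# Sigmoid-style head: num_relations channels, the largest label no longer fits.
fits_sigmoid = labels_outside_channels(gt_rel_labels, num_relations)

print(fits_softmax, fits_sigmoid)
```

Under that assumption, only the highest relation label falls outside the narrower output, so the problem would not show up for zero-based subject and object labels.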

Jingkang50 commented 1 year ago

I am not quite sure either. That code comes from mmdet. To avoid this, I would suggest writing your own loss function.

ShunchiZhang commented 1 year ago

Thanks. I'll try to figure it out by myself.

Jingkang50 commented 1 year ago

@GSeanCDAT You could also comment here.

GSeanCDAT commented 1 year ago

@ShunchiZhang Hi there, I think a possible solution is to turn the `gt_rels` into one-hot vectors before calculating the loss (make sure the first element of each one-hot vector is the background class).
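
A minimal sketch of that suggestion, written with plain Python lists so it runs without any dependencies. The function name `one_hot_with_background` and the `keep_background` switch are hypothetical. Dropping the background column so that the targets match a `num_relations`-wide sigmoid output is an assumption on top of the comment above, not something confirmed in this thread.

```python
def one_hot_with_background(labels, num_relations, keep_background=False):
    """One-hot encode relation labels, with index 0 as the background class.

    labels: integers in [0, num_relations], where 0 means background.
    keep_background=True  -> rows of width num_relations + 1.
    keep_background=False -> the background column is dropped, so a
        background label becomes an all-zero row of width num_relations.
    """
    width = num_relations + 1
    rows = []
    for lab in labels:
        if not 0 <= lab < width:
            raise ValueError(f"label {lab} outside [0, {num_relations}]")
        row = [0.0] * width
        row[lab] = 1.0
        rows.append(row if keep_background else row[1:])
    return rows


# Tiny example with 3 relations: background, relation 1, relation 3.
targets = one_hot_with_background([0, 1, 3], num_relations=3)
full = one_hot_with_background([0, 1, 3], num_relations=3, keep_background=True)
print(targets)
print(full)
```

With tensors, `torch.nn.functional.one_hot(labels, num_classes=num_relations + 1)` performs the same expansion; the result would then need to be cast to float and, if the head outputs only `num_relations` channels, sliced to remove column 0 before it is passed to a binary cross-entropy loss.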