MCG-NJU / AdaMixer

[CVPR 2022 Oral] AdaMixer: A Fast-Converging Query-Based Object Detector
MIT License

RuntimeError: output with shape [16, 100, 100] doesn't match the broadcast shape [1, 16, 100, 100] #14

Closed: Fangyi-Chen closed this issue 2 years ago.

Fangyi-Chen commented 2 years ago

Hi, I am running the training code and I hit the following error. I am using PyTorch 1.3.1, torchvision 0.4.2, and CUDA 10.1.

    File "AdaMixer/mmdet/models/roi_heads/adamixer_decoder.py", line 60, in _bbox_forward
      featmap_strides=self.featmap_strides)
    File "/home/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
      result = self.forward(*input, **kwargs)
    File "/home/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 95, in new_func
      return old_func(*args, **kwargs)
    File "AdaMixer/mmdet/models/roi_heads/bbox_heads/adamixer_decoder_stage.py", line 302, in forward
      attn_mask=attn_bias,
    File "/home/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
      result = self.forward(*input, **kwargs)
    File "/home/python3.7/site-packages/mmcv/cnn/bricks/transformer.py", line 137, in forward
      key_padding_mask=key_padding_mask)[0]
    File "/home/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
      result = self.forward(*input, **kwargs)
    File "/home/python3.7/site-packages/torch/nn/modules/activation.py", line 783, in forward
      attn_mask=attn_mask)
    File "/home/python3.7/site-packages/torch/nn/functional.py", line 3352, in multi_head_attention_forward
      attn_output_weights += attn_mask
    RuntimeError: output with shape [16, 100, 100] doesn't match the broadcast shape [1, 16, 100, 100]
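The failing statement is the in-place `attn_output_weights += attn_mask` in `torch/nn/functional.py`. An in-place add cannot grow its left-hand side, so it raises as soon as broadcasting against the mask yields a larger shape. The broadcast shape [1, 16, 100, 100] in the message suggests that a mask of shape [16, 100, 100] (batch size times number of heads, queries, keys) was given an extra leading dimension, which is what happens when code expecting a 2D (L, S) mask unsqueezes it at dim 0. The sketch below reproduces the same class of error with plain tensors; the sizes come from the error message, and reading 16 as 2 images times 8 heads is an assumption that does not change the outcome. The helper `inplace_add_fails` is hypothetical, written only for this illustration.

```python
import torch

# Sizes from the error message. Assumption: 16 = batch_size * num_heads (e.g. 2 * 8).
BH, L, S = 16, 100, 100

# Attention logits as laid out inside multi_head_attention_forward: (B*H, L, S).
attn_output_weights = torch.zeros(BH, L, S)

# A per-head 3D bias of shape (B*H, L, S).
attn_mask_3d = torch.randn(BH, L, S)


def inplace_add_fails(dst: torch.Tensor, src: torch.Tensor) -> bool:
    """Hypothetical helper: True if `dst += src` raises a RuntimeError."""
    try:
        dst += src  # in-place: the result must keep dst's shape
    except RuntimeError:
        return True
    return False


# 3D mask unsqueezed at dim 0 -> (1, B*H, L, S): broadcast shape exceeds dst, error.
failed_4d = inplace_add_fails(attn_output_weights.clone(), attn_mask_3d.unsqueeze(0))

# 3D mask added as-is -> shapes match, no error.
failed_3d = inplace_add_fails(attn_output_weights.clone(), attn_mask_3d)

# 2D (L, S) mask unsqueezed to (1, L, S) -> broadcasts to dst's shape, no error.
failed_2d = inplace_add_fails(attn_output_weights.clone(), torch.randn(L, S).unsqueeze(0))

print(failed_4d, failed_3d, failed_2d)  # True False False
```

Only the first case fails, and its shapes are exactly the ones in the traceback. PyTorch 1.5 and later should accept a 3D `attn_mask` of shape (batch_size * num_heads, L, S) without the extra unsqueeze, which is consistent with the maintainer's advice to upgrade.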

Looking forward to your reply, and many thanks. Could you also tell me which PyTorch and CUDA versions you used?

sebgao commented 2 years ago

Please install pytorch>=1.5.0 (#12).
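To fail early instead of deep inside the attention layer, a training script could check the installed version up front. A minimal sketch, assuming a hypothetical helper `torch_version_at_least` that is not part of AdaMixer, mmcv, or PyTorch; it takes a version string so it can be tested without importing torch.

```python
import re


def torch_version_at_least(version: str, required=(1, 5, 0)) -> bool:
    """Hypothetical helper: compare a torch.__version__-style string to `required`.

    Drops any local build suffix such as "+cu101" and compares the first three
    numeric components; pre-release tags like "1.5.0a0" compare as 1.5.0.
    """
    public = version.split("+", 1)[0]
    parts = [int(p) for p in re.findall(r"\d+", public)[:3]]
    parts += [0] * (3 - len(parts))  # pad short versions, e.g. "2.1" -> (2, 1, 0)
    return tuple(parts) >= tuple(required)


print(torch_version_at_least("1.3.1"))         # False
print(torch_version_at_least("1.5.0"))         # True
print(torch_version_at_least("1.13.1+cu117"))  # True
```

In a training script this would be called as `torch_version_at_least(torch.__version__)`. Where the `packaging` library is available, `packaging.version.parse` handles pre-release and local version tags more completely.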

If you do not want to change your PyTorch version, changing the line at https://github.com/MCG-NJU/AdaMixer/blob/fd0f721c45596da4615900a57fa43d4bcddb275a/mmdet/models/roi_heads/bbox_heads/adamixer_decoder_stage.py#L294 to

attn_bias = (iof * self.iof_tau.view(1, -1, 1, 1))

may help.