Fangyi-Chen / SQR

MIT License

How to test the SQR decoder individually? #8

Open Originlightwkp opened 1 year ago

Originlightwkp commented 1 year ago

I would like to use the SQR decoder from your paper to improve the decoder in my model. To avoid affecting the downstream tasks, I first need to test the decoder on its own.

I have written some test code, but it still fails. Is there a recommended way to test the decoder separately?

```python
if __name__ == '__main__':
    x = torch.rand((300, 4, 256)).cuda()
    valid_ratios = torch.rand((4, 1, 2))
    reference_points = torch.rand((4, 300, 4))
    reg_branches = None
    model = QRDeformableDetrTransformerDecoder(
        num_layers=6,
        return_intermediate=True,
        start_q=[0, 0, 1, 2, 4, 7, 12],  # 2
        end_q=[1, 2, 4, 7, 12, 20, 33],  # 2
        transformerlayers=dict(
            type='DetrTransformerDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                dict(
                    type='MultiScaleDeformableAttention',
                    embed_dims=256)
            ],
            feedforward_channels=1024,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))).cuda()
    out = model(x, valid_ratios=valid_ratios, reference_points=reference_points, reg_branches=None)
    print(out[0].shape)
    print(out[1].shape)
```

Running the code above raises the following error:


```
Traceback (most recent call last):
  File "/media/cheng/dataset4/codeto2022/FairMOT_modifed/src/sqrtest.py", line 381, in <module>
    out = model(x,valid_ratios=valid_ratios,reference_points=reference_points,reg_branches=None)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/cheng/dataset4/codeto2022/FairMOT_modifed/src/sqrtest.py", line 223, in forward
    **kwargs)
  File "/media/cheng/dataset4/codeto2022/FairMOT_modifed/src/sqrtest.py", line 144, in forward
    **kwargs)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/mmcv/cnn/bricks/transformer.py", line 850, in forward
    **kwargs)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/mmcv/utils/misc.py", line 340, in new_func
    output = old_func(*args, **kwargs)
  File "/media/cheng/dataset4/annaconda3/envs/fashionclip/lib/python3.7/site-packages/mmcv/ops/multi_scale_deform_attn.py", line 312, in forward
    assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
TypeError: 'NoneType' object is not subscriptable
```
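The assertion fails because `spatial_shapes` is `None`. The cross-attention in each layer is mmcv's `MultiScaleDeformableAttention`, which samples from the encoder output passed as `value` and also needs `spatial_shapes` (the H and W of each feature level) and `level_start_index` (where each level starts inside the flattened `value`); the test call passes none of them. In addition, `valid_ratios` should have one row per feature level, and `valid_ratios` and `reference_points` are created on the CPU while `x` and the model are on CUDA, so all inputs need to be moved to the same device. Below is a minimal sketch of dummy encoder-side inputs, assuming that `QRDeformableDetrTransformerDecoder` keeps the forward signature of mmdet's `DeformableDetrTransformerDecoder` and that the mmcv 1.x defaults apply (`batch_first=False`, `num_levels=4`). `make_dummy_decoder_inputs` is a hypothetical helper, not part of SQR or mmcv, and the four level sizes are arbitrary.

```python
import torch


def make_dummy_decoder_inputs(bs=4, num_queries=300, embed_dims=256,
                              level_shapes=((32, 32), (16, 16), (8, 8), (4, 4)),
                              device='cpu'):
    """Random stand-ins for what a Deformable-DETR-style decoder receives.

    Hypothetical helper; shapes assume mmcv 1.x with batch_first=False.
    """
    # (H, W) of each feature level, one row per level: (num_levels, 2)
    spatial_shapes = torch.as_tensor(level_shapes, dtype=torch.long, device=device)
    # number of flattened positions in each level
    level_sizes = spatial_shapes[:, 0] * spatial_shapes[:, 1]
    # offset at which each level starts inside the flattened `value`
    level_start_index = torch.cat((level_sizes.new_zeros((1,)),
                                   level_sizes.cumsum(0)[:-1]))
    # total length of `value`; the failing assertion compares against this
    num_value = int(level_sizes.sum())
    num_levels = spatial_shapes.shape[0]
    return dict(
        # decoder queries: (num_queries, bs, embed_dims), same layout as `x` above
        query=torch.rand(num_queries, bs, embed_dims, device=device),
        # stand-in for the encoder memory: (num_value, bs, embed_dims)
        value=torch.rand(num_value, bs, embed_dims, device=device),
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        # valid (w, h) ratio per image and per level: (bs, num_levels, 2)
        valid_ratios=torch.rand(bs, num_levels, 2, device=device),
        # normalized (cx, cy, w, h) reference boxes: (bs, num_queries, 4)
        reference_points=torch.rand(bs, num_queries, 4, device=device),
    )
```

With the model built from the config above and moved to the same device, the call would then become `out = model(reg_branches=None, **make_dummy_decoder_inputs(device='cuda'))`. This is untested against the SQR code; if its decoder also expects `query_pos` or a `key_padding_mask`, those would have to be supplied as well.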