I think the code "reference_points[:, :, None]" should be "reference_points[:, None, :]". The last dimension of "torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]" has length 4 and holds [w_ratio, h_ratio, w_ratio, h_ratio], which lines up with [x, y, w, h], so the last dimension of "reference_points" should be [x, y, w, h].
However, with "reference_points[:, :, None]", the last dimension becomes [x, x, x, x], [y, y, y, y], [w, w, w, w], or [h, h, h, h] after broadcasting. In this case the last dimension of "reference_points_input" still has length 4, but reference_points_input[..., 0] == reference_points_input[..., 2] and reference_points_input[..., 1] == reference_points_input[..., 3]. That means the result is [x * w_ratio, x * h_ratio, x * w_ratio, x * h_ratio].
But what we want is [x * w_ratio, y * h_ratio, w * w_ratio, h * h_ratio].
The original code is:
https://github.com/IDEA-opensource/DAB-DETR/blob/309f6ad92af7a62d7732c1bdf1e0c7a69a7bdaef/models/dab_deformable_detr/deformable_transformer.py#L390
https://github.com/IDEA-opensource/DAB-DETR/blob/309f6ad92af7a62d7732c1bdf1e0c7a69a7bdaef/models/dab_deformable_detr/deformable_transformer.py#L391
https://github.com/IDEA-opensource/DAB-DETR/blob/309f6ad92af7a62d7732c1bdf1e0c7a69a7bdaef/models/dab_deformable_detr/deformable_transformer.py#L392
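A small shape check may help settle which indexing gives which result. This is only a sketch: the shapes (bs, nq, 4) for "reference_points" and (bs, n_levels, 2) for "src_valid_ratios", the sizes 2, 3, 5, and the helper names "expanded", "expected", "matches_per_coordinate", and "alt_shape" are assumptions made for illustration and are not taken from the repository.

```python
import torch

# Assumed shapes (hypothetical, not confirmed against the repository):
#   reference_points: (bs, nq, 4)       holding [x, y, w, h]
#   src_valid_ratios: (bs, n_levels, 2) holding [w_ratio, h_ratio]
bs, nq, n_levels = 2, 3, 5
torch.manual_seed(0)
reference_points = torch.rand(bs, nq, 4)
src_valid_ratios = torch.rand(bs, n_levels, 2)

# Same expression as in the linked code: (bs, 1, n_levels, 4)
ratios = torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]

# None in the third index slot inserts a new axis at position 2,
# so the trailing axis of size 4 stays last: (bs, nq, 1, 4)
expanded = reference_points[:, :, None]
reference_points_input = expanded * ratios  # broadcasts to (bs, nq, n_levels, 4)

# Reference result built element by element:
# [x * w_ratio, y * h_ratio, w * w_ratio, h * h_ratio]
expected = torch.empty(bs, nq, n_levels, 4)
for b in range(bs):
    for q in range(nq):
        for lvl in range(n_levels):
            x, y, w, h = reference_points[b, q]
            w_ratio, h_ratio = src_valid_ratios[b, lvl]
            expected[b, q, lvl] = torch.stack(
                [x * w_ratio, y * h_ratio, w * w_ratio, h * h_ratio]
            )

matches_per_coordinate = torch.allclose(reference_points_input, expected)

# The proposed alternative has shape (bs, 1, nq, 4); multiplying it by
# (bs, 1, n_levels, 4) only broadcasts when nq == n_levels (or one is 1).
try:
    alt_shape = tuple((reference_points[:, None, :] * ratios).shape)
except RuntimeError:
    alt_shape = None

print(tuple(expanded.shape), tuple(reference_points_input.shape))
print(matches_per_coordinate, alt_shape)
```

If the shape assumptions above match the real tensors, "[:, :, None]" on a 3-D tensor adds the new axis before the last dimension rather than after it, so the last dimension would remain [x, y, w, h]. Only "reference_points[..., None]" would move each coordinate into its own trailing axis of size 1.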
Is there something wrong with my understanding?