xiuqhou / Salience-DETR

[CVPR 2024] Official implementation of the paper "Salience DETR: Enhancing Detection Transformer with Hierarchical Salience Filtering Refinement"
https://arxiv.org/abs/2403.16131
Apache License 2.0
111 stars 7 forks source link

pos_embed算子输入的shape不匹配 #34

Open CerrieJ opened 1 month ago

CerrieJ commented 1 month ago

Question

跑训练过程遇到 pos_embed算子输入的shape不匹配,请教大概是什么原因呢?我是pt2.3,其他以来版本是requirments.txt中内容

[rank5]: File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank5]: return forward_call(*args, **kwargs) [rank5]: File "/gpfs/xj/Salience-DETR_mlu/models/bricks/salience_transformer.py", line 565, in forward [rank5]: query_with_pos = key_with_pos = self.with_pos_embed(query, query_pos) [rank5]: File "/gpfs/xj/Salience-DETR_mlu/models/bricks/salience_transformer.py", line 545, in with_pos_embed [rank5]: return tensor if pos is None else tensor + pos [rank5]: RuntimeError: The size of tensor a (1092) must match the size of tensor b (192) at non-singleton dimension 1

补充信息

No response

xiuqhou commented 1 month ago

这里报错是因为decoder中query和query_pos维度不匹配,正常情况下两者应该都是1092,其中192是去噪部分,900是匹配部分。但这里似乎query_pos缺少了匹配部分,请问您有修改过代码吗?跑的是coco数据集还是自己数据集呢?另外能否具体说明下是什么训练配置下出现的这个问题?

xiuqhou commented 1 month ago

query和query_pos分别来自targets和reference_points,可以调试检查下salience_transformer.py第220行-221行的targets和reference_points的维度是不是匹配,他俩前两个维度应该是相同的

CerrieJ commented 1 month ago

我使用的官方coco2017,按照readme 跑的单卡训练任务。我发现可能是因为我没有编译 models/bricks/ops/cuda/ms_deform_attn_cuda.cu,因为我当前不是在NVIDIA-GPU上运行,想请教一下,是否有办法绕开这个cudac算子,比如使用pytorch已有的算子替换?

xiuqhou commented 1 month ago

算子如果编译失败就会自动用pytorch原生算子替换,应该不用修改。逻辑在ms_deform_attn.py前几行。

报错是在decoder发生的,感觉不是cuda算子问题,否则应该encoder就会报错

CerrieJ commented 1 month ago

(Pdb) p reference_points.size() torch.Size([2, 200, 4]) (Pdb) p target.size() torch.Size([2, 1100, 256]) 我的target 和 reference_points 的第二个维度确实不匹配,请问进一步排查方向是什么呢?

xiuqhou commented 1 month ago

reference_points一部分来自noised_box_query,另一部分来自enc_outputs_coord,看起来是缺少了来自enc_outputs_coord的部分。

请在salience_transformer.py第203行和215行分别打个断点(请注意对应到您修改后的代码行数),看一看他们的维度是否是下面这样:

203行:中间维度可能不是18239,只要他俩的数字相同就行。

图片

215行:中间维度必须都是900

图片

如果维度和我这边的不同,说明问题应该是出现在203-212行,可能是计算topk_index的过程出现了问题