HikariTJU / LD

Localization Distillation for Object Detection (CVPR 2022, TPAMI 2023)
Apache License 2.0
355 stars 51 forks source link

请问如果使用自己的feature-imitation方法应该在哪里添加,您的feature-imitation代码写在哪里了,我没有找到 #60

Open yg333 opened 1 year ago

HikariTJU commented 1 year ago

计算imitation区域的在 https://github.com/HikariTJU/LD/blob/bc60bbcd48e9305e61e32a0d2c981621e2ae0d05/mmdet/models/dense_heads/ld_head.py#L580

计算loss的在 https://github.com/HikariTJU/LD/blob/bc60bbcd48e9305e61e32a0d2c981621e2ae0d05/mmdet/models/dense_heads/ld_head.py#L170

yg333 commented 1 year ago

请问我的环境都是对的 但是您的源码训练就报错,弄了好几天,实在没辙了

HikariTJU commented 1 year ago

请使用标准模板报告bug, 我推荐你可以先试一下单卡能不能跑

python tools/train.py ${config}

以下是标准模板

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

Describe the bug A clear and concise description of what the bug is.

Reproduction

  1. What command or script did you run?
    A placeholder for the command.
  2. Did you make any modifications on the code or config? Did you understand what you have modified?
  3. What dataset did you use?

Environment

  1. Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
  2. You may add addition that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback If applicable, paste the error trackback here.

A placeholder for trackback.

Bug fix If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

yg333 commented 1 year ago

谢谢您,单卡训练就可以了。还有个问题我没搞清楚。比如我想用Masked Generative Distillation论文的特征图蒸馏方法。我只需要改这段代码就可以了吗?谢谢您的及时回复

  if self.imitation_method == 'gibox':
            gi_idx = self.get_gi_region(soft_label, cls_score, anchors,
                                        bbox_pred, soft_targets, stride)
            gi_teacher = teacher_x[gi_idx]
            gi_student = x[gi_idx]

            loss_im = self.loss_im(gi_student, gi_teacher)
HikariTJU commented 1 year ago
  1. 单卡能跑多卡不能跑说明,你可能多卡的launch命令错了
  2. 你需要写一个函数计算蒸馏区域,然后把计算loss改成对应蒸馏区域的教师和学生特征图。很难用文字描述,你需要自己意会一下,参考get_im_region函数
yg333 commented 1 year ago

可是特征图蒸馏区域不是整个feature map吗,经过一系列网络模块得到最终的学生特征图后,直接与教师特征图做L2损失。我不太理解您说的这个蒸馏区域是啥?