chenbinghui1 / DSL

CVPR2022 paper "Dense Learning based Semi-Supervised Object Detection"
Apache License 2.0
101 stars 10 forks source link

train debug #8

Closed weyoung0 closed 2 years ago

weyoung0 commented 2 years ago

换用自己数据集训练时报错:

Traceback (most recent call last): File "./tools/train.py", line 202, in main() File "./tools/train.py", line 190, in main train_detector( File "/secret/ZLW/Codes/SSOD/DSL/mmdet/apis/train.py", line 218, in train_detector runner.run(data_loaders, cfg.workflow) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/runner/hooks/semi_epoch_based_runner.py", line 344, in run epoch_runner(data_loaders[i], kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/runner/hooks/semi_epoch_based_runner.py", line 265, in train self.run_iter(data_batch, train_mode=True, kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/runner/hooks/semi_epoch_based_runner.py", line 155, in run_iter outputs = self.model.train_step(data_batch, self.optimizer, File "/usr/local/lib/python3.8/site-packages/mmcv/parallel/distributed.py", line 52, in train_step output = self.module.train_step(inputs[0], kwargs[0]) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/detectors/base.py", line 237, in train_step losses = self(data) File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(input, kwargs) File "/usr/local/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 97, in new_func return old_func(args, kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/detectors/base.py", line 171, in forward return self.forward_train(img, img_metas, kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/detectors/single_stage.py", line 82, in forward_train losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/dense_heads/base_dense_head.py", line 54, in forward_train losses = self.loss(loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) File "/usr/local/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 185, in new_func return old_func(*args, *kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/dense_heads/fcos_head.py", line 309, in loss loss_cls = self.loss_cls( File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(input, kwargs) File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/losses/focal_loss.py", line 170, in forward loss_cls = self.loss_weight * calculate_loss_func( File "/secret/ZLW/Codes/SSOD/DSL/mmdet/models/losses/focal_loss.py", line 85, in sigmoid_focal_loss loss = _sigmoid_focal_loss(pred.contiguous(), target, gamma, alpha, None, File "/usr/local/lib/python3.8/site-packages/mmcv/ops/focal_loss.py", line 39, in forward assert input.size(0) == target.size(0) AssertionError

batch设为8,输入分辨率设为512x512,debug了一下,发现在semi_epoch_based_runner.py第186行开始, data_batch['img_metas']、data_batch['gt_bboxes']data_batch['gt_labels']添加了一个元素,而data_batch['img'] cat了一个batch-1的图像tensor。导致网络的模型输入tensor维度变成(15,3,512,512),而label相关的信息为9张图像的,进而在计算loss时出现了AssertionError。 请问大佬这里是我代码没理解对还是确实有bug呢?

chenbinghui1 commented 2 years ago

@weyoung0 建议用batch=2 用别的话修改这个地方成1张图就行 就行 https://github.com/chenbinghui1/DSL/blob/45ee8fd1bc267f8d9fb1763d4979d7b0a9efc989/mmdet/runner/hooks/semi_epoch_based_runner.py#L193

weyoung0 commented 2 years ago

@chenbinghui1 感谢大佬的快速回复,请问batch为2有什么说法吗?因为我输入尺寸调的512,batch为2只占一丁点显存,浪费好多GPU,事实上我训的时候调成了72。。还有就是监督baseline训出来的指标和我自己的模型差好多,不知道哪里出了问题

chenbinghui1 commented 2 years ago

@chenbinghui1 因为当前代码只扩增了一张无标注数据,batch太多其他的unlabel数据都没有得到扩增,所以相当于没有怎么用这个Lscale。另外 监督模型那边 那就得对齐你的模型了,监督模型就是普通fcos的resnet50没有别的东西;可能是augmentation不同把。

weyoung0 commented 2 years ago

@chenbinghui1 好的,我再对齐下监督模型试试,非常感谢!