MiniBullLab / easy_ai

3 stars 1 forks source link

segnet精度损失 #111

Closed foww-0001 closed 3 years ago

foww-0001 commented 3 years ago

develop分支

Epoch: 82 | mIoU: 0.776 | background: 0.997 lane: 0.555

edge_tools分支

Epoch: 19 | mIoU: 0.885 | background: 0.999 lane: 0.770
foww-0001 commented 3 years ago

data里面中的segnet.pt和代码中的segnet模型层名称不匹配,因此模型参数未导入。替换data中的segnet.pt后精度恢复。

foww-0001 commented 3 years ago

问题已解决。

foww-0001 commented 3 years ago

test的分支需要进行测试。

foww-0001 commented 3 years ago

test分支上的训练loss为负:

Epoch: 0[0/52]   Loss: -1.4999257        Rate: 0.0000000         Time: 4.09368  
Epoch: 0[1/52]   Loss: -2.6540663        Rate: 0.0000000         Time: 0.15645  
Epoch: 0[2/52]   Loss: 0.6849564         Rate: 0.0000000         Time: 0.12994  
Epoch: 0[3/52]   Loss: -2.3447042        Rate: 0.0000000         Time: 0.13117  
Epoch: 0[4/52]   Loss: -0.4011593        Rate: 0.0000000         Time: 0.13032  
Epoch: 0[5/52]   Loss: -0.9078372        Rate: 0.0000000         Time: 0.13245  
Epoch: 0[6/52]   Loss: -3.8072939        Rate: 0.0000000         Time: 0.13036  
Epoch: 0[7/52]   Loss: -0.9506326        Rate: 0.0000000         Time: 0.13034  
Epoch: 0[8/52]   Loss: -0.2740065        Rate: 0.0000000         Time: 0.13226  
Epoch: 0[9/52]   Loss: -1.3429866        Rate: 0.0000000         Time: 0.13053  
Epoch: 0[10/52]  Loss: -0.5071906        Rate: 0.0000000         Time: 0.13104  

正在定位问题中。

foww-0001 commented 3 years ago

其中segnet运行脚本如下:

 ./easy_tools/train_scripts/SegNet.sh /home/wfw/data/VOCdevkit/CarScratch_segment/ImageSets/train.txt /home/wfw/data/VOCdevkit/CarScratch_segment/ImageSets/val.txt

config文件比对后没有异常,模型参数确认正确无误。

foww-0001 commented 3 years ago

master训练正常,确认数据无误。

foww-0001 commented 3 years ago

定位出来是loss的问题。

foww-0001 commented 3 years ago

定位ce2d_loss.py中加入reduce=False。

if self.weight_type == 0:
    loss = F.binary_cross_entropy(input_data, targets,
                                  weight=self.weight,
                                  reduction=self.reduction)
else:
    loss = F.binary_cross_entropy(input_data, targets,
                                  + reduce=False,
                                  reduction=self.reduction)
foww-0001 commented 3 years ago

data模型已经上传118服务器。精度问题已经解决。