bubbliiiing / yolov7-pytorch

This is a YOLOv7 library that can be used to train on your own datasets.
GNU General Public License v3.0
861 stars 150 forks

Error when running predict: is this change, related to the fix at line 109 of yolo.py, reasonable? #61

Closed ClarkeAC closed 1 year ago

ClarkeAC commented 1 year ago

Environment: 4 GPUs, torch182

When no GPU is specified, the following error occurs:

ValueError: Cannot assign non-leaf Tensor to parameter 'weight'. Model parameters must be created explicitly. To express 'weight' as a function of another Tensor, compute the value in the forward() method.

When a single GPU is specified (CUDA_VISIBLE_DEVICES=0), the error does not occur.
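
One way a parameter can end up as the "non-leaf Tensor" named in the error is an in-place copy_ from a tensor that still tracks gradients. The following is a minimal sketch of that behaviour, assuming this is the mechanism involved here (the thread does not confirm where in the multi-GPU path the error is raised). The names frozen, source and clean are illustrative and do not come from the repository.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a frozen layer and a weight that still tracks gradients.
frozen = nn.Conv2d(1, 1, kernel_size=1, bias=False).requires_grad_(False)
source = torch.ones(1, 1, 1, 1, requires_grad=True) * 2.0  # non-leaf, has a grad_fn

# An in-place copy from a graph-attached tensor attaches the parameter to the graph.
frozen.weight.copy_(source)
attached_is_leaf = frozen.weight.is_leaf

# nn.Module refuses to register such a tensor as a parameter.
try:
    nn.Module().register_parameter("weight", frozen.weight)
    registration_error = None
except ValueError as exc:
    registration_error = str(exc)

# Detaching the source first leaves the parameter a plain leaf tensor.
clean = nn.Conv2d(1, 1, kernel_size=1, bias=False).requires_grad_(False)
clean.weight.copy_(source.detach())
clean_is_leaf = clean.weight.is_leaf

print(attached_is_leaf, clean_is_leaf)
print(registration_error)
```

Running the copy inside a torch.no_grad() block would be another way to keep the destination detached from the graph.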

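I saw this in another issue:

    After replacing self.net = self.net.fuse().eval() on line 109 of yolo.py with self.net = self.net.eval(), the problem was solved, and evaluation and prediction work normally. What is the difference between the two?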

Originally posted by @runge2020 in https://github.com/bubbliiiing/yolov7-pytorch/issues/6#issuecomment-1181418693

Modifying lines 202 and 206 of ./net/yolo.py also makes the error go away, without needing the change above. Is this modification reasonable? Thanks.

import torch
import torch.nn as nn

def fuse_conv_and_bn(conv, bn):
    # Build a frozen Conv2d with the same geometry to hold the fused parameters.
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Fold the BN scale into the conv weights.
    w_conv  = conv.weight.clone().view(conv.out_channels, -1)
    w_bn    = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    # Original: fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape).detach())

    # Fold the conv bias and the BN shift into a single bias.
    b_conv  = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn    = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    # Original: fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    fusedconv.bias.copy_((torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn).detach())
    return fusedconv

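
To make the difference between fuse().eval() and eval() concrete, here is a framework-free sketch of the arithmetic that fuse_conv_and_bn performs, reduced to a 1x1 convolution at a single pixel so that the convolution is just a matrix-vector product. All helper names and numbers are illustrative assumptions, not values from the repository.

```python
import math

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def conv_then_bn(w, b, gamma, beta, mean, var, eps, x):
    """Unfused path: 1x1 conv, then batch norm using running statistics."""
    y = [yi + bi for yi, bi in zip(matvec(w, x), b)]
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]

def fuse(w, b, gamma, beta, mean, var, eps):
    """Fold the BN scale and shift into the conv weight and bias."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    w_fused = [[s * wij for wij in row] for s, row in zip(scale, w)]
    b_fused = [s * bi + bt - s * m
               for s, bi, bt, m in zip(scale, b, beta, mean)]
    return w_fused, b_fused

# Illustrative parameters: 2 input channels, 2 output channels.
w = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var, eps = [0.2, -0.1], [4.0, 0.25], 1e-5
x = [1.0, 3.0]

reference = conv_then_bn(w, b, gamma, beta, mean, var, eps, x)
w_f, b_f = fuse(w, b, gamma, beta, mean, var, eps)
fused = [yi + bi for yi, bi in zip(matvec(w_f, x), b_f)]
max_diff = max(abs(r - f) for r, f in zip(reference, fused))
print(max_diff)
```

Because batch norm in eval mode uses fixed running statistics, the folding is exact up to floating-point rounding. The fused layer only saves the separate normalization step, so skipping fuse() should change speed rather than predictions.
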
bubbliiiing commented 1 year ago

Are the results the same? It looks fine to me.

ClarkeAC commented 1 year ago

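OK, thanks. The detection results appear to be no different. At least the change avoids that problem, and the detection speed matches what is expected with fuse().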
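
A numerical comparison is one way to turn "appear to be no different" into a measured check. The sketch below does this for a single Conv+BN pair. It uses PyTorch's built-in torch.nn.utils.fusion.fuse_conv_bn_eval rather than the function posted in this thread, and the layer sizes and random statistics are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

# Hypothetical stand-in for one Conv+BN pair of the network.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():
    # Non-trivial statistics so the comparison is meaningful.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Both layers must be in eval mode before fusing.
unfused = nn.Sequential(conv, bn).eval()
fused = fuse_conv_bn_eval(unfused[0], unfused[1])

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_abs_diff = (unfused(x) - fused(x)).abs().max().item()
print(max_abs_diff)
```

For the full detector, the same idea would apply to the raw outputs of the fused and unfused networks on one image. A difference at the level of float32 rounding would suggest the detach() change did not alter the fused weights.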