Hope1337 / YOWOv3

46 stars 7 forks source link

Some doubts about restarting training #30

Open T-wow opened 1 month ago

T-wow commented 1 month ago

def fuse_conv(conv, norm):
    """Fold a BatchNorm2d into the Conv2d that precedes it, for inference.

    The returned convolution computes ``norm(conv(x))`` as evaluated with the
    norm's running statistics (eval mode). It is used by ``YOLO.fuse()`` to
    replace ``Conv.conv`` once ``Conv.norm`` is dropped. The fused parameters
    do not require gradients, so the result is meant for inference only.

    Args:
        conv: the ``torch.nn.Conv2d`` to fuse; its bias may be ``None``.
        norm: the ``torch.nn.BatchNorm2d`` applied to ``conv``'s output.

    Returns:
        A new ``torch.nn.Conv2d`` (with bias) on ``conv.weight.device``.
    """
    fused_conv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        # Dilation must be carried over; otherwise dilated convs are fused wrongly.
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    )
    # ``requires_grad_`` (trailing underscore) is the in-place setter.
    fused_conv = fused_conv.requires_grad_(False).to(conv.weight.device)

    # Build the fused tensors without recording an autograd graph, so the new
    # parameters stay frozen leaf tensors.
    with torch.no_grad():
        # An affine-less BatchNorm has weight/bias == None: treat as 1 / 0.
        gamma = norm.weight if norm.weight is not None else torch.ones_like(norm.running_var)
        beta = norm.bias if norm.bias is not None else torch.zeros_like(norm.running_mean)
        # Per-output-channel scale gamma / sqrt(var + eps).
        scale = gamma / torch.sqrt(norm.running_var + norm.eps)

        # Scale each output channel's filter (same as diag(scale) @ W).
        w_conv = conv.weight.reshape(conv.out_channels, -1)
        fused_conv.weight.copy_((scale.view(-1, 1) * w_conv).view(fused_conv.weight.size()))

        if conv.bias is None:
            b_conv = torch.zeros(conv.out_channels, device=conv.weight.device, dtype=conv.weight.dtype)
        else:
            b_conv = conv.bias
        # scale * (b - mean) + beta == scale * b + (beta - gamma * mean / sqrt(var + eps))
        fused_conv.bias.copy_(scale * (b_conv - norm.running_mean) + beta)

    return fused_conv

class Conv(torch.nn.Module):
    """Convolution block: Conv2d (no bias) -> BatchNorm2d -> SiLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: kernel size.
        s: stride.
        p: padding; passed with ``k`` and ``d`` to the ``pad`` helper
            (defined elsewhere) to obtain the actual padding.
        d: dilation.
        g: number of groups.
    """

    def __init__(self, in_ch, out_ch, k=1, s=1, p=None, d=1, g=1):
        super().__init__()
        padding = pad(k, p, d)
        # No conv bias: the BatchNorm that follows provides the shift.
        self.conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s,
                                    padding=padding, dilation=d, groups=g,
                                    bias=False)
        self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03)
        # NOTE: attribute is named ``relu`` but holds a SiLU activation.
        self.relu = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm, then the activation (unfused path)."""
        y = self.conv(x)
        y = self.norm(y)
        return self.relu(y)

    def fuse_forward(self, x):
        """Apply conv then activation; installed as ``forward`` by ``YOLO.fuse``
        after ``norm`` has been folded into ``conv``."""
        y = self.conv(x)
        return self.relu(y)

  class YOLO(torch.nn.Module):
      """Detector built from a DarkNet backbone followed by a DarkFPN.

      Args:
          width: width setting, forwarded to both DarkNet and DarkFPN.
          depth: depth setting, forwarded to both DarkNet and DarkFPN.
          pretrain_path: stored as ``self.pretrain_path``; not read anywhere
              in this class.
      """

      def __init__(self, width, depth, pretrain_path):
          super().__init__()
          self.net = DarkNet(width, depth)
          self.fpn = DarkFPN(width, depth)
          self.pretrain_path = pretrain_path

      def forward(self, x):
          """Run the backbone, then return the FPN's output on its features."""
          features = self.net(x)
          return self.fpn(features)

      def fuse(self):
          """Fold BatchNorm into convolution in every ``Conv`` block.

          For each submodule whose exact type is ``Conv`` and that still has a
          ``norm``: its ``conv`` is replaced by ``fuse_conv(conv, norm)``, its
          ``forward`` is switched to ``fuse_forward`` and ``norm`` is deleted.
          Blocks already fused no longer have ``norm`` and are skipped, so
          repeated calls are harmless. Intended for inference only.

          Returns:
              ``self``, to allow chaining.
          """
          for module in self.modules():
              if type(module) is not Conv or not hasattr(module, 'norm'):
                  continue
              module.conv = fuse_conv(module.conv, module.norm)
              module.forward = module.fuse_forward
              delattr(module, 'norm')
          return self

    Because of the `fuse_conv()` and `fuse(self)` functions above, if I resume training, the `Conv` class in this file will be used, so both the fusion and the freezing of weights will take place. Wouldn't this mean that the network is not trained effectively after training is resumed?
    Thanks!
Hope1337 commented 1 month ago

@T-wow Batch-norm fusion only happens at the inference stage, and that step was removed (for some reason). You can restore the fusion step to speed up inference; it will not affect anything during training.

T-wow commented 1 month ago

@Hope1337 Thank you for your response. I have another question: is there a setting in your program for freezing weights during training?

Hope1337 commented 1 month ago

@T-wow Yes, I added it a few days ago. See the new config file for details.