WongKinYiu / yolov7

Implementation of paper - YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors
GNU General Public License v3.0

Fuse ImplicitA and Convolution #441

Open ghost opened 2 years ago

ghost commented 2 years ago

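Hello Dr. Chien-Yao, I tried running the fuse layer in the detect head once and it raised an error, so I reimplemented part of it. Could you please check whether this is correct?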

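import torch
import torch.nn as nn


@torch.no_grad()
def fuse_conv_and_ia(conv, ia):
    # Fold ImplicitA (added to the conv input: conv(x + a)) into the conv bias.
    # Exact for the 1x1 convs in the detect head: W(x + a) + b = Wx + (Wa + b).
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: conv.weight is (out, in, 1, 1), ia.implicit is (1, in, 1, 1)
    c1, c2, _, _ = conv.weight.shape
    c1_, c2_, _, _ = ia.implicit.shape

    w_conv = conv.weight.clone().reshape(c1, c2)
    b_conv = conv.bias.clone()
    w_ia = ia.implicit.clone().reshape(c2_, c1_)

    # The weights are unchanged and must be copied over; only the bias absorbs the offset
    fusedconv.weight.copy_(conv.weight)
    fusedconv.bias.copy_(torch.matmul(w_conv, w_ia).squeeze(1) + b_conv)

    return fusedconv


@torch.no_grad()
def fuse_conv_and_im(conv, im):
    # Fold ImplicitM (multiplied with the conv output: m * conv(x)) into the conv.
    # m * (Wx + b) = (m * W)x + m * b, where m scales each output channel.
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: im.implicit is (1, out, 1, 1)
    c1, c2, _, _ = im.implicit.shape
    w_conv = conv.weight.clone()
    b_conv = conv.bias.clone()

    w1_im = im.implicit.clone().reshape(c2)        # (out,) scales the bias
    w2_im = im.implicit.clone().transpose(0, 1)   # (out, 1, 1, 1) scales each filter
    fusedconv.bias.copy_(b_conv * w1_im)
    fusedconv.weight.copy_(w_conv * w2_im)
    return fusedconv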
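A quick way to confirm that folding is exact is to compare outputs numerically. The sketch below is self-contained, so it uses minimal stand-ins for `ImplicitA`/`ImplicitM` (hypothetical copies that only mirror the `(1, C, 1, 1)` `implicit` parameter and the add/multiply in `forward`) and a hypothetical helper `fold_implicit` that applies both folds to a 1x1 convolution. It assumes the head computes `im(conv(ia(x)))`.

```python
import torch
import torch.nn as nn


class ImplicitA(nn.Module):
    # Hypothetical stand-in: learned additive offset of shape (1, C, 1, 1),
    # randomized with a large spread so the check below is non-trivial
    def __init__(self, channel):
        super().__init__()
        self.implicit = nn.Parameter(torch.randn(1, channel, 1, 1) * 0.5)

    def forward(self, x):
        return self.implicit + x


class ImplicitM(nn.Module):
    # Hypothetical stand-in: learned multiplicative scale of shape (1, C, 1, 1)
    def __init__(self, channel):
        super().__init__()
        self.implicit = nn.Parameter(1.0 + torch.randn(1, channel, 1, 1) * 0.5)

    def forward(self, x):
        return self.implicit * x


@torch.no_grad()
def fold_implicit(conv, ia, im):
    """Return a new 1x1 conv whose output equals im(conv(ia(x)))."""
    out_c, in_c = conv.weight.shape[:2]
    fused = nn.Conv2d(in_c, out_c, kernel_size=1, bias=True)
    w = conv.weight.reshape(out_c, in_c)
    # ImplicitA: W(x + a) + b = Wx + (Wa + b)
    bias = conv.bias + w @ ia.implicit.reshape(in_c)
    # ImplicitM: m * (Wx + b') = (m * W)x + m * b'
    scale = im.implicit.reshape(out_c)
    fused.weight.copy_(conv.weight * scale.reshape(out_c, 1, 1, 1))
    fused.bias.copy_(bias * scale)
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(8, 6, kernel_size=1)
ia, im = ImplicitA(8), ImplicitM(6)
x = torch.randn(2, 8, 5, 5)

with torch.no_grad():
    reference = im(conv(ia(x)))
    fused_out = fold_implicit(conv, ia, im)(x)

max_err = (reference - fused_out).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

For kernels larger than 1x1 with zero padding, the ImplicitA fold is no longer exact at the borders, because the padded zeros never receive the offset.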
WongKinYiu commented 2 years ago

The new version provides a more concise fuse function; you could check it.
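The reply does not include the code itself. As a rough illustration of a more concise style, the sketch below folds both implicit layers into each existing convolution in place instead of building a new `nn.Conv2d`. The module `TinyHead`, its attributes `m`, `ia`, `im`, and its `fuse` method are hypothetical; they only assume the same layout as the issue: an additive offset before each 1x1 conv and a multiplicative scale after it.

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    # Hypothetical head: one 1x1 conv per level, with an additive implicit
    # offset applied before it and a multiplicative implicit scale after it
    def __init__(self, in_chs=(4, 8), out_ch=6):
        super().__init__()
        self.m = nn.ModuleList([nn.Conv2d(c, out_ch, 1) for c in in_chs])
        self.ia = nn.ParameterList([nn.Parameter(torch.randn(1, c, 1, 1)) for c in in_chs])
        self.im = nn.ParameterList([nn.Parameter(1 + torch.randn(1, out_ch, 1, 1)) for _ in in_chs])
        self.fused = False

    def forward(self, xs):
        outs = []
        for i, x in enumerate(xs):
            if self.fused:
                outs.append(self.m[i](x))
            else:
                outs.append(self.im[i] * self.m[i](x + self.ia[i]))
        return outs

    @torch.no_grad()
    def fuse(self):
        # Edit each conv in place; the bias absorbs W @ a before W is rescaled
        for conv, a, s in zip(self.m, self.ia, self.im):
            out_c, in_c = conv.weight.shape[:2]
            conv.bias += conv.weight.reshape(out_c, in_c) @ a.reshape(in_c)
            conv.bias *= s.reshape(out_c)
            conv.weight *= s.transpose(0, 1)  # (out, 1, 1, 1) per-filter scale
        self.fused = True
        return self


torch.manual_seed(0)
head = TinyHead()
xs = [torch.randn(2, 4, 3, 3), torch.randn(2, 8, 3, 3)]
with torch.no_grad():
    before = head(xs)
    after = head.fuse()(xs)
print(all(torch.allclose(b, a, atol=1e-4) for b, a in zip(before, after)))
```

Updating in place keeps the existing module objects, so nothing else in the model needs rewiring. The order of the three in-place updates matters: the bias must absorb `W @ a` using the unscaled weight, and only then are the bias and weight multiplied by the scale.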