gyyang23 / AFPN


Help implementing in YOLOv5 #6

Open glenn-jocher opened 1 year ago

glenn-jocher commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!


Anoue commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!


I have reproduced the code myself, but it seems to be somewhat different from the author's. I can provide it to you if necessary.

wossg-999 commented 1 year ago

@Anoue Could you please share the code with me so I can learn from it? If you can, here is my email address: plause@126.com. Thanks!

gyyang23 commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!


Thank you! YOLOv5+AFPN is implemented based on MMYOLO. The code is as follows:


from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmyolo.registry import MODELS

from mmengine.model import BaseModule

def BasicConv(filter_in, filter_out, kernel_size, stride=1):
    pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("silu", nn.SiLU(inplace=True)),
    ]))

def Conv(filter_in, filter_out, kernel_size, stride=1, pad=0):
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("silu", nn.SiLU(inplace=True)),
    ]))

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, filter_in, filter_out):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(filter_in, filter_out, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(filter_out, momentum=0.1)
        self.silu = nn.SiLU(inplace=True)
        self.conv2 = nn.Conv2d(filter_out, filter_out, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(filter_out, momentum=0.1)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.silu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.silu(out)

        return out

class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(Upsample, self).__init__()

        self.upsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        )

    def forward(self, x, ):
        x = self.upsample(x)
        return x

class Downsample_x2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample_x2, self).__init__()

        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, 2, 2)
        )

    def forward(self, x, ):
        x = self.downsample(x)

        return x

class Downsample_x4(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample_x4, self).__init__()

        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, 4, 4)
        )

    def forward(self, x, ):
        x = self.downsample(x)

        return x

class ASFF_2(nn.Module):
    def __init__(self, inter_dim=512):
        super(ASFF_2, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)

        self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, input1, input2):
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :]

        out = self.conv(fused_out_reduced)

        return out

class ASFF_3(nn.Module):
    def __init__(self, inter_dim=512):
        super(ASFF_3, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_3 = BasicConv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)

        self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, input1, input2, input3):
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)
        level_3_weight_v = self.weight_level_3(input3)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :] + \
                            input3 * levels_weight[:, 2:, :, :]

        out = self.conv(fused_out_reduced)

        return out

class ScaleBlockBody(nn.Module):
    def __init__(self, channels=[128, 256, 512]):
        super(ScaleBlockBody, self).__init__()

        self.blocks_top1 = nn.Sequential(
            BasicConv(channels[0], channels[0], 1),
        )
        self.blocks_mid1 = nn.Sequential(
            BasicConv(channels[1], channels[1], 1),
        )
        self.blocks_bot1 = nn.Sequential(
            BasicConv(channels[2], channels[2], 1),
        )

        self.downsample_top1_2 = Downsample_x2(channels[0], channels[1])
        self.upsample_mid1_2 = Upsample(channels[1], channels[0], scale_factor=2)

        self.asff_top1 = ASFF_2(inter_dim=channels[0])
        self.asff_mid1 = ASFF_2(inter_dim=channels[1])

        self.blocks_top2 = nn.Sequential(
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0])
        )
        self.blocks_mid2 = nn.Sequential(
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1])
        )

        self.downsample_top2_2 = Downsample_x2(channels[0], channels[1])
        self.downsample_top2_4 = Downsample_x4(channels[0], channels[2])
        self.downsample_mid2_2 = Downsample_x2(channels[1], channels[2])
        self.upsample_mid2_2 = Upsample(channels[1], channels[0], scale_factor=2)
        self.upsample_bot2_2 = Upsample(channels[2], channels[1], scale_factor=2)
        self.upsample_bot2_4 = Upsample(channels[2], channels[0], scale_factor=4)

        self.asff_top2 = ASFF_3(inter_dim=channels[0])
        self.asff_mid2 = ASFF_3(inter_dim=channels[1])
        self.asff_bot2 = ASFF_3(inter_dim=channels[2])

        self.blocks_top3 = nn.Sequential(
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0])
        )
        self.blocks_mid3 = nn.Sequential(
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1])
        )
        self.blocks_bot3 = nn.Sequential(
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2])
        )

    def forward(self, x):
        x1, x2, x3 = x

        x1 = self.blocks_top1(x1)
        x2 = self.blocks_mid1(x2)
        x3 = self.blocks_bot1(x3)

        top = self.asff_top1(x1, self.upsample_mid1_2(x2))
        mid = self.asff_mid1(self.downsample_top1_2(x1), x2)

        x1 = self.blocks_top2(top)
        x2 = self.blocks_mid2(mid)

        top = self.asff_top2(x1, self.upsample_mid2_2(x2), self.upsample_bot2_4(x3))
        mid = self.asff_mid2(self.downsample_top2_2(x1), x2, self.upsample_bot2_2(x3))
        bot = self.asff_bot2(self.downsample_top2_4(x1), self.downsample_mid2_2(x2), x3)

        top = self.blocks_top3(top)
        mid = self.blocks_mid3(mid)
        bot = self.blocks_bot3(bot)

        return top, mid, bot

@MODELS.register_module()
class AFPN(BaseModule):
    def __init__(self, in_channels=[256, 512, 1024], out_channels=[256, 512, 1024]):
        super(AFPN, self).__init__()

        self.conv1 = BasicConv(in_channels[0], in_channels[0] // 4, 1)
        self.conv2 = BasicConv(in_channels[1], in_channels[1] // 4, 1)
        self.conv3 = BasicConv(in_channels[2], in_channels[2] // 4, 1)

        self.body = nn.Sequential(
            ScaleBlockBody([in_channels[0] // 4, in_channels[1] // 4, in_channels[2] // 4])
        )

        self.conv11 = BasicConv(in_channels[0] // 4, out_channels[0], 1)
        self.conv22 = BasicConv(in_channels[1] // 4, out_channels[1], 1)
        self.conv33 = BasicConv(in_channels[2] // 4, out_channels[2], 1)

        # ----------------------------------------------------------------#
        #   init weight
        # ----------------------------------------------------------------#
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
                torch.nn.init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        x1, x2, x3 = x

        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x3 = self.conv3(x3)

        out1, out2, out3 = self.body([x1, x2, x3])

        out1 = self.conv11(out1)
        out2 = self.conv22(out2)
        out3 = self.conv33(out3)

        return tuple([out1, out2, out3])
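
As a quick sanity check, the neck above can be exercised on dummy P3/P4/P5 features; the spatial sizes below assume a 640x640 input and the default channel settings:

import torch

# assumes the AFPN / ScaleBlockBody definitions above are importable (requires mmyolo + mmengine)
neck = AFPN(in_channels=[256, 512, 1024], out_channels=[256, 512, 1024])
feats = [torch.zeros(1, 256, 80, 80),    # P3 (stride 8)
         torch.zeros(1, 512, 40, 40),    # P4 (stride 16)
         torch.zeros(1, 1024, 20, 20)]   # P5 (stride 32)
for out in neck(feats):
    print(out.shape)  # (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)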
athrunsunny commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!


@glenn-jocher You can try this: https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

glenn-jocher commented 1 year ago

@athrunsunny hey that looks like a great source, it's already all set up with YOLOv5 YAMLs and files.

wossg-999 commented 1 year ago

@gyyang23 Thanks a lot

Anoue commented 1 year ago
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()

        self.upsample = nn.Sequential(
            Conv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        )        
    def forward(self, x):
        x = self.upsample(x)
        return x

class Downsample_x2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, 2, 2, 0)
        )

    def forward(self, x):
        x = self.downsample(x)
        return x

class Downsample_x4(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, 4, 4, 0)
        )

    def forward(self, x):
        x = self.downsample(x)
        return x

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, c1, c2):
        super().__init__()
        self.cv1 = nn.Conv2d(c1, c2, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(c2, momentum=0.1)
        self.act = nn.SiLU(inplace=True)
        self.cv2 = nn.Conv2d(c2, c2, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(c2, momentum=0.1)

    def forward(self, x):
        residual = x

        x = self.cv1(x)
        x = self.bn1(x)
        x = self.act(x)

        x = self.cv2(x)
        x = self.bn2(x)

        x += residual
        x = self.act(x)

        return x

class ASFF_2(nn.Module):
    def __init__(self, c1, c2, level=0):
        super().__init__()
        c1_l, c1_h = c1[0], c1[1]
        self.level = level
        self.dim = [
            c1_l,
            c1_h
        ]
        self.inter_dim = self.dim[self.level]
        compress_c = 8

        if level == 0:
            self.stride_level_1 = Upsample(c1_h, self.inter_dim)
        if level == 1:
            self.stride_level_0 = Downsample_x2(c1_l, self.inter_dim)

        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weights_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)
        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, x):
        x_level_0, x_level_1 = x[0], x[1]

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
        elif self.level == 1:
            level_0_resized = self.stride_level_0(x_level_0)
            level_1_resized = x_level_1

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v), 1)
        levels_weight = self.weights_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                            level_1_resized * levels_weight[:, 1:2, :, :]
        out = self.conv(fused_out_reduced)

        return out

class ASFF_3(nn.Module):
    def __init__(self, c1, c2, level=0):
        super().__init__()
        c1_l, c1_m, c1_h = c1[0], c1[1], c1[2]
        self.level = level
        self.dim = [
            c1_l,
            c1_m,
            c1_h
        ]
        self.inter_dim = self.dim[self.level]
        compress_c = 8

        if level == 0:
            self.stride_level_1 = Upsample(c1_m, self.inter_dim)
            self.stride_level_2 = Upsample(c1_h, self.inter_dim, scale_factor=4)

        if level == 1:
            self.stride_level_0 = Downsample_x2(c1_l, self.inter_dim)
            self.stride_level_2 = Upsample(c1_h, self.inter_dim)

        if level == 2:
            self.stride_level_0 = Downsample_x4(c1_l, self.inter_dim)
            self.stride_level_1 = Downsample_x2(c1_m, self.inter_dim)

        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weights_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, x):
        x_level_0, x_level_1, x_level_2 = x[0], x[1], x[2]

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_resized = self.stride_level_2(x_level_2)

        elif self.level == 1:
            level_0_resized = self.stride_level_0(x_level_0)
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)

        elif self.level == 2:
            level_0_resized = self.stride_level_0(x_level_0)
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)

        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weights_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                            level_1_resized * levels_weight[:, 1:2, :, :] + \
                            level_2_resized * levels_weight[:, 2:, :, :]

        out = self.conv(fused_out_reduced)

        return out

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  320
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4   160
   [-1, 3, C3, [128]],     #160
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  80
   [-1, 6, C3, [256]],  #80
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  40
   [-1, 9, C3, [512]],   #40
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32  20
   [-1, 3, C3, [1024]],  #20
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[4, 1, Conv, [64, 1, 1]], # 80 80
  [6, 1, Conv, [128, 1, 1]],  # 40 40
  [9, 1, Conv, [256, 1, 1]],  #20 20

  [10, 1, Conv, [64, 1, 1]],  # 13 80 80
  [11, 1, Conv, [128, 1, 1]], # 14 40 40
  [12, 1, Conv, [256, 1, 1]], # 15 20 20

  [[13, 14], 1, ASFF_2, [64, 0]],  #  16 80 80 
  [[13, 14], 1, ASFF_2, [128, 1]],  #  17 40 40

  [16, 1, BasicBlock, [64]],  # 18 80 80
  [17, 1, BasicBlock, [128]], # 19 40 40

  [[18, 19, 15], 1, ASFF_3, [64, 0]],  # 20
  [[18, 19, 15], 1, ASFF_3, [128, 1]], # 21
  [[18, 19, 15], 1, ASFF_3, [256, 2]], # 22

  [20, 1, BasicBlock, [64]],  #23
  [21, 1, BasicBlock, [128]], #24
  [22, 1, BasicBlock, [256]], #25

  [23, 1, Conv, [256, 1, 1]], #26
  [24, 1, Conv, [512, 1, 1]], #27
  [25, 1, Conv, [1024, 1, 1]], #28

  [[26, 27, 28], 1, Detect, [nc, anchors]],
  ]

# additions inside parse_model() (models/yolo.py); BasicBlock is appended to the existing module list
if m in [Conv, GhostConv, BasicBlock]:
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
elif m is ASFF_2:
    c1, c2 = [ch[f[0]], ch[f[1]]], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
elif m is ASFF_3:
    c1, c2 = [ch[f[0]], ch[f[1]], ch[f[2]]], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
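
A minimal smoke test for the setup above, assuming the modules are added to models/common.py, the parse_model branches to models/yolo.py, and the YAML is saved as models/yolov5s-afpn.yaml in a YOLOv5 checkout (the file names are placeholders):

import torch
from models.yolo import Model  # YOLOv5 repository

# build the model from the AFPN YAML and run a dummy forward pass
model = Model('models/yolov5s-afpn.yaml', ch=3, nc=80)
preds = model(torch.zeros(1, 3, 640, 640))
for p in preds:
    print(p.shape)  # one tensor per detection scale (P3, P4, P5)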
wossg-999 commented 1 year ago
(quotes the YOLOv5 AFPN code, YAML, and parse_model changes from the previous comment)

nice work!

nyj-ocean commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!


@glenn-jocher Have you implemented the AFPN for YOLOv5 or YOLOv8?

nyj-ocean commented 1 year ago

@gyyang23 thanks for your great work! Based on the code you provided (https://github.com/gyyang23/AFPN/issues/6#issuecomment-1621982852), I attempted to replicate YOLOv5+AFPN in the MMYOLO framework.

(1) I copied your AFPN code (https://github.com/gyyang23/AFPN/issues/6#issuecomment-1621982852) and created a new AFPN.py file in mmyolo-main/mmyolo/models/necks/

AFPN.txt

(2) I changed the neck from YOLOv5PAFPN to AFPN in the mmyolo-main/configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py file: yolov5_s-v61_syncbn_8xb16-300e_coco-AFPN.txt

(3) I ran python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py, but I got the following error:

Traceback (most recent call last):
  File "tools/train.py", line 117, in <module>
    main()
  File "tools/train.py", line 113, in main
    runner.train()
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1735, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/runner/loops.py", line 96, in run
    self.run_epoch()
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/runner/loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/runner/loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 340, in _run_forward
    results = self(**data, mode=mode)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 92, in forward
    return self.loss(inputs, data_samples)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmdet/models/detectors/single_stage.py", line 77, in loss
    x = self.extract_feat(batch_inputs)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/mmdet/models/detectors/single_stage.py", line 148, in extract_feat
    x = self.neck(x)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/S/nyj/mm-main/mm/models/necks/AFPN.py", line 271, in forward
    x1 = self.conv1(x1)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/yjn/.conda/envs/mm/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 256, 1, 1], expected input[4, 128, 80, 80] to have 256 channels, but got 128 channels instead

How can I resolve this error?

How do I determine the in_channels and out_channels after changing the neck from YOLOv8PAFPN to AFPN?
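
For what it's worth, the weight shape in the error ([64, 256, 1, 1]) means conv1 was built for 256 input channels while the incoming backbone feature only has 128, i.e. the neck's in_channels do not match what the backbone actually outputs. Unlike YOLOv5PAFPN, the AFPN above uses its channel arguments literally (no widen_factor is applied inside the neck), so a hedged config sketch would be (channel values inferred from the error message, not from the author):

# hypothetical MMYOLO config override; adjust the lists to your backbone's real output channels
model = dict(
    neck=dict(
        _delete_=True,  # drop the YOLOv5PAFPN-specific keys inherited from the base config
        type='AFPN',
        in_channels=[128, 256, 512],
        out_channels=[128, 256, 512]))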

glenn-jocher commented 1 year ago

@gyyang23 great work! We are interested in implementing AFPN for YOLOv5 and YOLOv8. Do you have the YOLOv5 code you used for the training here or a detailed architecture diagram we could use to create the torch modules? Thanks!

@glenn-jocher Have you implemented the AFPN for YOLOv5 or YOLOv8?

@nyj-ocean I'm going to try to create an AFPN PR this week in https://github.com/ultralytics/ultralytics based on https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

If you'd like to help that would be great!

EDIT: PR opened in https://github.com/ultralytics/ultralytics/pull/3612

glenn-jocher commented 1 year ago

@Anoue thanks for the YOLOv5 code, but I have a question. Your YAML is different than the one from

https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

Screenshot 2023-07-09 at 15 17 56

Which one is a more correct implementation of the paper?

Anoue commented 1 year ago

@Anoue thanks for the YOLOv5 code, but I have a question. Your YAML is different than the one from

https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501 Screenshot 2023-07-09 at 15 17 56

Which one is a more correct implementation of the paper?

The source code provided by the author does not use the C3 structure; instead it uses consecutive 1x1 convolutions. My code was modified entirely based on the author's source code.

Anoue commented 1 year ago

@Anoue thanks for the YOLOv5 code, but I have a question. Your YAML is different than the one from

https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501 Screenshot 2023-07-09 at 15 17 56

Which one is a more correct implementation of the paper?

Sorry, I just found a new bug: the author applies BasicBlock three times to the output of ASFF_3, and I only used it once. The improved code is here: https://github.com/Anoue/yolov5s/blob/main/AFPN.py

athrunsunny commented 1 year ago

@Anoue thanks for the YOLOv5 code, but I have a question. Your YAML is different than the one from

https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501 Screenshot 2023-07-09 at 15 17 56

Which one is a more correct implementation of the paper?

The original author applies 1x1 convolutions four times after each ASFF_x. This code retains the original YOLOv5 network structure. If you want to use the original author's method, you can replace the C3 structure with four 1x1 convolutions.

glenn-jocher commented 1 year ago

@Anoue ok, got it. I ran some experiments with your original YAML on YOLOv8s, but got lower mAP than the default. I'll try the new YAML configuration. Even though YOLOv8 is different from YOLOv5, I think if AFPN works for one it should work for the other as well.

Screenshot 2023-07-10 at 11 50 13
glenn-jocher commented 1 year ago

@gyyang23 I was looking at your mmyolo code. Is your AFPN() module supposed to essentially replace the YOLOv5 head part here and output straight to Detect()? i.e. for me to reproduce this in the YOLOv5 YAML should the head basically be:

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head with AFPN update
head:
   [[[4, 6, 9], 1, AFPN, [256, 512, 1024]],  # mmyolo AFPN module
    [[-1], 1, Detect, [nc]],  # Detect(P3, P4, P5)
   ]
nyj-ocean commented 1 year ago

@glenn-jocher @athrunsunny @Anoue https://github.com/gyyang23/AFPN/issues/6#issuecomment-1622961551 https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

I referred to the two AFPN implementations linked above and tried training YOLOv8+AFPN on my dataset. The experiment results are as follows (all experiments were conducted with the YOLOv8 framework, https://github.com/ultralytics/ultralytics):

method | P | R | mAP50 | Author | link
-- | -- | -- | -- | -- | --
v8m | 0.723 | 0.747 | 0.765 | glenn-jocher | https://github.com/ultralytics/ultralytics
v8m-AFPN-github | 0.755 | 0.731 | 0.768 | Anoue | https://github.com/gyyang23/AFPN/issues/6#issuecomment-1622961551
v8m-AFPN-CSDN | 0.715 | 0.735 | 0.758 | athrunsunny | https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501
Anoue commented 1 year ago

@glenn-jocher @athrunsunny @Anoue #6 (comment) https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

I referred to the two AFPN implementations linked above and tried training YOLOv8+AFPN on my dataset. The experiment results are as follows (all experiments were conducted with the YOLOv8 framework, https://github.com/ultralytics/ultralytics):

method | P | R | mAP50 | Author | link
-- | -- | -- | -- | -- | --
v8m | 0.723 | 0.747 | 0.765 | glenn-jocher | https://github.com/ultralytics/ultralytics
v8m-AFPN-github | 0.755 | 0.731 | 0.768 | Anoue | #6 (comment)
v8m-AFPN-CSDN | 0.715 | 0.735 | 0.758 | athrunsunny | https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

Wow, that result looks alright to me. How many images and categories does your dataset have? Is it very different from COCO? How many epochs did you train for?

nyj-ocean commented 1 year ago

@Anoue My dataset: 1 class, 4323 images, train : val = 8 : 2. Experiment setting: 150 epochs; all other parameters are the same, using the default parameters of YOLOv8.

nyj-ocean commented 1 year ago

@Anoue In your improved code https://github.com/Anoue/yolov5s/blob/main/AFPN.py , the number of BasicBlock_n repeats is 9. Is that right? Why not 3?

#   [20, 9, BasicBlock_n, [64]],  #23
#   [21, 9, BasicBlock_n, [128]], #24
#   [22, 9, BasicBlock_n, [256]], #25
Anoue commented 1 year ago

@Anoue In your improved code https://github.com/Anoue/yolov5s/blob/main/AFPN.py , the number of BasicBlock_n repeats is 9. Is that right? Why not 3?

#   [20, 9, BasicBlock_n, [64]],  #23
#   [21, 9, BasicBlock_n, [128]], #24
#   [22, 9, BasicBlock_n, [256]], #25

Because it is multiplied by depth_multiple (9 × 0.33 ≈ 3).
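
For reference, this is the scaling YOLOv5's parse_model applies to the per-layer repeat count (gd is depth_multiple), so the 9 in the YAML becomes 3 for the s model:

def scaled_repeats(n, gd):
    # from YOLOv5's parse_model: n = max(round(n * gd), 1) if n > 1 else n
    return max(round(n * gd), 1) if n > 1 else n

print(scaled_repeats(9, 0.33))  # 3  (s: depth_multiple = 0.33)
print(scaled_repeats(9, 0.67))  # 6  (m: depth_multiple = 0.67)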

Anoue commented 1 year ago

@Anoue I cannot run your improved code.

error:

Traceback (most recent call last):
  File "train.py", line 216, in <module>
    train()
  File "train.py", line 212, in train
    trainer.train()
  File "F:\UserData\nj\ultralytics-main\ultralytics\yolo\engine\trainer.py", line 191, in train
    self._do_train(world_size)
  File "F:\UserData\nj\ultralytics-main\ultralytics\yolo\engine\trainer.py", line 266, in _do_train
    self._setup_train(world_size)
  File "F:\UserData\nj\ultralytics-main\ultralytics\yolo\engine\trainer.py", line 205, in _setup_train
    ckpt = self.setup_model()
  File "F:\UserData\nj\ultralytics-main\ultralytics\yolo\engine\trainer.py", line 436, in setup_model
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
  File "train.py", line 61, in get_model
    model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
  File "F:\UserData\nj\ultralytics-main\ultralytics\nn\tasks.py", line 185, in __init__
    self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
  File "F:\UserData\nj\ultralytics-main\ultralytics\nn\tasks.py", line 485, in parse_model
    c2 = ch[f]
TypeError: list indices must be integers or slices, not list

tasks.py :

        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, nn.ConvTranspose2d, DWConvTranspose2d, C3x, BasicBlock_n): 
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, BasicBlock_n,C1, C2, C2f, C3, C3TR, C3Ghost, C3x, C3STR):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]

        #AFPN 
        elif m is ASFF_2:
            c1, c2 = [ch[f[0]], ch[f[1]]], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
        elif m is ASFF_3:
            c1, c2 = [ch[f[0]], ch[f[1]], ch[f[2]]], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

        else:
            c2 = ch[f]       ##########  485

How many classes did you use for training? Has nc been modified? And are depth_multiple=0.33 and width_multiple=0.50?

nyj-ocean commented 1 year ago

@Anoue sorry, the reason I could not run your improved code is that I had missed some code. Actually, there is no problem with your improved code; it runs.

By the way, if depth_multiple=0.67, what should the number of BasicBlock_n repeats be set to? 2 and 5?

  - [16, 2, BasicBlock_n, [64]]  # 18 80 80
  - [17, 2, BasicBlock_n, [128]] # 19 40 40

  - [[18, 19, 15], 1, ASFF_3, [64, 0]]  # 20
  - [[18, 19, 15], 1, ASFF_3, [128, 1]] # 21
  - [[18, 19, 15], 1, ASFF_3, [256, 2]] # 22

  - [20, 5, BasicBlock_n, [64]]  #23
  - [21, 5, BasicBlock_n, [128]] #24
  - [22, 5, BasicBlock_n, [256]] #25
Anoue commented 1 year ago

@Anoue sorry, the reason I could not run your improved code is that I had missed some code. Actually, there is no problem with your improved code; it runs.

By the way, if depth_multiple=0.67, what should the number of BasicBlock_n repeats be set to? 2 and 5?

  - [16, 2, BasicBlock_n, [64]]  # 18 80 80
  - [17, 2, BasicBlock_n, [128]] # 19 40 40

  - [[18, 19, 15], 1, ASFF_3, [64, 0]]  # 20
  - [[18, 19, 15], 1, ASFF_3, [128, 1]] # 21
  - [[18, 19, 15], 1, ASFF_3, [256, 2]] # 22

  - [20, 5, BasicBlock_n, [64]]  #23
  - [21, 5, BasicBlock_n, [128]] #24
  - [22, 5, BasicBlock_n, [256]] #25

For the s version, each ASFF_3 output is processed by 3 BasicBlocks. The author only released the source code for the s version, so I don't know whether all versions use just 3 BasicBlocks or whether, like C3, different versions use different numbers of blocks.

nyj-ocean commented 1 year ago

@glenn-jocher

The mAP calculated when the last epoch of training is completed does not match the mAP calculated using the python val.py command.

(1) The mAP calculated when the last epoch of training is completed is 0.784

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
    150/150      19.1G     0.8093     0.5067      1.132          4       1280: 100%|██████████| 289/289 [03
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.766      0.733      0.783      0.464

Optimizer stripped from runs\detect\v8m-AFPN\weights\best.pt, 39.8MB

Validating runs\detect\v8m-AFPN\weights\best.pt...
YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.753      0.752      0.784      0.464
Speed: 0.6ms preprocess, 6.9ms inference, 0.0ms loss, 1.0ms postprocess per image

(2) the mAP calculated using the python val.py command is 0.765 (using the best.pt)

python val.py
YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
val: Scanning F:\UserData2\nj\data-set\val\labels.cache... 865 images, 0 backgrounds, 0 corrupt: 100%
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.728      0.753      0.765      0.438

Why are the two mAP values different? Which one is correct, 0.784 or 0.765?

Anoue commented 1 year ago

@glenn-jocher

The mAP calculated when the last epoch of training is completed does not match the mAP calculated using the python val.py command.

(1) The mAP calculated when the last epoch of training is completed is 0.784

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
    150/150      19.1G     0.8093     0.5067      1.132          4       1280: 100%|██████████| 289/289 [03
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.766      0.733      0.783      0.464

Optimizer stripped from runs\detect\v8m-AFPN\weights\best.pt, 39.8MB

Validating runs\detect\v8m-AFPN\weights\best.pt...
YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.753      0.752      0.784      0.464
Speed: 0.6ms preprocess, 6.9ms inference, 0.0ms loss, 1.0ms postprocess per image

(2) the mAP calculated using the python val.py command is 0.765 (using the best.pt)

python val.py
YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
val: Scanning F:\UserData2\nj\data-set\val\labels.cache... 865 images, 0 backgrounds, 0 corrupt: 100%
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
                   all        865       1685      0.728      0.753      0.765      0.438

Why are the two mAP values different? Which one is correct, 0.784 or 0.765?

Is the dataset you use when running python val.py the default dataset or the test set from your training?

nyj-ocean commented 1 year ago

@Anoue

(1) In data.yaml :

#path:  # dataset root dir
train: F:\UserData\nj\person-data-set\train\images  # train images (relative to 'path') 128 images
val:   F:\UserData\nj\person-data-set\val\images  # val images (relative to 'path') 128 images

names:
  0: person

(2) In default.yaml

# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True  # validate/test during training
split: val  # dataset split to use for validation, i.e. 'val', 'test' or 'train'
save_json: False  #True # False  # save results to JSON file
save_hybrid: False  # save hybrid version of labels (labels + additional predictions)
conf:  # object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7  # intersection over union (IoU) threshold for NMS

As you can see in https://github.com/gyyang23/AFPN/issues/6#issuecomment-1631808644, the Class / Images / Instances counts are the same when calculating mAP during training and when using python val.py.

YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
                 Class     Images  Instances     
                   all        865       1685 
Anoue commented 1 year ago

@Anoue

* I don't use a `test` set. I only use the `train` and `val` sets.

(1) In data.yaml :

#path:  # dataset root dir
train: F:\UserData\nj\person-data-set\train\images  # train images (relative to 'path') 128 images
val:   F:\UserData\nj\person-data-set\val\images  # val images (relative to 'path') 128 images

names:
  0: person

(2) In default.yaml

# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True  # validate/test during training
split: val  # dataset split to use for validation, i.e. 'val', 'test' or 'train'
save_json: False  #True # False  # save results to JSON file
save_hybrid: False  # save hybrid version of labels (labels + additional predictions)
conf:  # object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7  # intersection over union (IoU) threshold for NMS
* I use the same `val` set in `training` and in `python val.py`.

As you can see in #6 (comment), the Class / Images / Instances counts are the same when calculating mAP during training and when using python val.py.

YOLOv8m-AFPN summary: 339 layers, 19660360 parameters, 0 gradients, 62.8 GFLOPs
                 Class     Images  Instances     
                   all        865       1685 

I haven't used YOLOv8, so I don't know whether the --data parameter is needed when running val.py. If it is needed, do you set it to the data.yaml you used for training?

gyyang23 commented 1 year ago

@gyyang23 I was looking at your mmyolo code. Is your AFPN() module supposed to essentially replace the YOLOv5 head part here and output straight to Detect()? i.e. for me to reproduce this in the YOLOv5 YAML should the head basically be:

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head with AFPN update
head:
   [[[4, 6, 9], 1, AFPN, [256, 512, 1024]],  # mmyolo AFPN module
    [[-1], 1, Detect, [nc]],  # Detect(P3, P4, P5)
   ]

Sorry, I don't understand 'output straight to Detect()'. I only changed the neck part of yolov5, leaving the rest unchanged. I did not modify the yolov5 head part.

TstMua commented 1 year ago

@glenn-jocher @athrunsunny @Anoue #6 (comment) https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

I referred to the two AFPN implementations linked above and tried training YOLOv8+AFPN on my dataset. The experiment results are as follows (all experiments were conducted with the YOLOv8 framework, https://github.com/ultralytics/ultralytics):

method | P | R | mAP50 | Author | link
-- | -- | -- | -- | -- | --
v8m | 0.723 | 0.747 | 0.765 | glenn-jocher | https://github.com/ultralytics/ultralytics
v8m-AFPN-github | 0.755 | 0.731 | 0.768 | Anoue | #6 (comment)
v8m-AFPN-CSDN | 0.715 | 0.735 | 0.758 | athrunsunny | https://blog.csdn.net/athrunsunny/article/details/131566311?spm=1001.2014.3001.5501

May I ask if you can provide your modified yolov8.yaml?

XadBo commented 1 year ago

@glenn-jocher @Anoue @athrunsunny I trained YOLOv5+AFPN (https://github.com/Anoue/yolov5s/blob/main/AFPN.py) on the VisDrone dataset, but the results are not good.

methods | P | R | mAP50 | mAP50-95
-- | -- | -- | -- | --
YOLOv5 | 0.436 | 0.339 | 0.328 | 0.176
YOLOv5+AFPN | 0.405 | 0.312 | 0.303 | 0.16

AFPN achieves better results on COCO, so I thought it should also work for VisDrone. Is AFPN not suitable for small targets, or is something wrong with my training?

xlnn commented 10 months ago

[16, 1, BasicBlock, [64]],  # 18 80 80
[17, 1, BasicBlock, [128]], # 19 40 40

Hello, I have the same problem. I modified yolov7.yaml as follows:

head:
  [[-1, 1, SPPCSPC, [512]], # 51

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [37, 1, Conv, [256, 1, 1]], # route backbone P4
   [[-1, -2], 1, Concat, [1]], # 55

   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]], # 58
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]], # 63

   [ [ 61, 63 ], 1, ASFF_2, [ 128, 0 ] ], #64
   [ [ 61, 63 ], 1, ASFF_2, [ 256, 1 ] ],

   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [24, 1, Conv, [128, 1, 1]], # route backbone P3
   [[-1, -2], 1, Concat, [1]], # 69

   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1]], # 77

   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 66], 1, Concat, [1]],

   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]], # 90

   [ [ 77, 90, 63 ], 1, ASFF_3, [ 64, 0 ] ], #
   [ [ 77, 90, 63 ], 1, ASFF_3, [ 128, 1 ] ], #
   [ [ 77, 90, 63 ], 1, ASFF_3, [ 256, 2 ] ], #

   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3, 51], 1, Concat, [1]],

   [-1, 1, Conv, [512, 1, 1]],
   [-2, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]], # 106

   [77, 1, RepConv, [256, 3, 1]],
   [90, 1, RepConv, [512, 3, 1]],
   [106, 1, RepConv, [1024, 3, 1]],

   [[107,108,109], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

The problems are:

/home/class1/.local/lib/python3.8/site-packages/torch/nn/functional.py:3454: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  warnings.warn(
Traceback (most recent call last):
  File "/home/class1/work/modify/yolov7_AFP/models/yolo.py", line 834, in <module>
    model = Model(opt.cfg).to(device)
  File "/home/class1/work/modify/yolov7_AFP/models/yolo.py", line 544, in __init__
    m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
  File "/home/class1/work/modify/yolov7_AFP/models/yolo.py", line 599, in forward
    return self.forward_once(x, profile)  # single-scale inference, train
  File "/home/class1/work/modify/yolov7_AFP/models/yolo.py", line 625, in forward_once
    x = m(x)  # run
  File "/home/class1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/class1/work/modify/yolov7_AFP/models/common.py", line 2124, in forward
    levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 16 and 32 in dimension 2 (The offending index is 1)

Process finished with exit code 1

Do you know what the problem is? Thank you!
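
For anyone debugging this, a minimal shape check (assuming the ASFF_2 / Conv / Upsample / Downsample_x2 definitions from the earlier comments are importable) shows that ASFF_2 expects its two inputs one stride apart. Layers 61 and 63 in the YAML above appear to be at the same resolution, which would produce exactly this kind of size mismatch inside torch.cat:

import torch

asff = ASFF_2(c1=[64, 128], c2=64, level=0)
p3 = torch.zeros(1, 64, 32, 32)    # finer level
p4 = torch.zeros(1, 128, 16, 16)   # one stride coarser; upsampled x2 inside ASFF_2
print(asff([p3, p4]).shape)        # torch.Size([1, 64, 32, 32])

# Feeding two maps at the same resolution makes the internal Upsample produce a map twice
# as large as the other, and torch.cat then fails with a size-mismatch RuntimeError.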