AlexeyAB / darknet

YOLOv4 / Scaled-YOLOv4 / YOLO - Neural Networks for Object Detection (Windows and Linux version of Darknet )
http://pjreddie.com/darknet/

ASFF - Learning Spatial Fusion for Single-Shot Object Detection - 63% mAP@0.5 with 45.5FPS #4382

Closed: Kyuuki93 closed this issue 3 years ago

Kyuuki93 commented 4 years ago

Learning Spatial Fusion for Single-Shot Object Detection


@AlexeyAB it seems worth taking a look.

AlexeyAB commented 4 years ago

ASFF significantly improves the box AP from 38.8% to 40.6% as shown in Table 3.



The paper also uses:

  1. BoF (MixUp, ...) - +4.2 mAP@0.5...0.95, but +0 mAP@0.5 and +5.6% AP@70: https://github.com/AlexeyAB/darknet/issues/3272

  2. MegDet: A Large Mini-Batch Object Detector (synchronized batch normalization technique); its abstract reports 52.5% mAP in the COCO 2017 Challenge, where the authors won 1st place in the Detection task: https://arxiv.org/abs/1711.07240v4 - issue: https://github.com/AlexeyAB/darknet/issues/4386

  3. Dropblock + Receptive field block gives +1.7% AP@0.5...0.95

  4. So ASFF itself gives only +1.8% AP@0.5...0.95, +1.5% AP@0.5 and +2.5% AP@.07

  5. cosine learning rate: https://github.com/AlexeyAB/darknet/pull/2651

Kyuuki93 commented 4 years ago

This paper is a bit confusing, so I looked at the author's code: it uses conv_bn_leakyReLU to compute the level_weights, instead of this formula, before the Softmax:

(image: the level-weight formula from the paper)

In short, ASFF maps the inputs x0, x1, x2 of yolo0, yolo1, yolo2 onto each other to enhance detection, but I still wonder which layers' outputs correspond to x0, x1, x2.

AlexeyAB commented 4 years ago

@Kyuuki93

his code using conv_bn_leakyReLU for the level_weights instead of this formula before Softmax

Can you provide a link to these lines of code?

Kyuuki93 commented 4 years ago

https://github.com/ruinmessi/ASFF/blob/master/models/network_blocks.py

calc weights: (screenshot)

weights func: (screenshot)

add_conv func: (screenshot)
AlexeyAB commented 4 years ago

@Kyuuki93

This paper is a bit confusing, so I took a look at his code, his code using conv_bn_leakyReLU for the level_weights instead of this formula before Softmax


This formula seems to be softmax a = exp(x1) / (exp(x1) + exp(x2) + exp(x3)) https://en.wikipedia.org/wiki/Softmax_function
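Written out in the paper's notation as I recall it (treat the exact symbols as a reconstruction: λ are the per-level logits produced by the 1x1 conv, x^{n→l} is level n resized to level l, and β, γ are defined analogously to α), this is a per-pixel softmax over three logits followed by a weighted sum, so α + β + γ = 1 at every position:

$$
\alpha^{l}_{ij} = \frac{e^{\lambda^{l}_{\alpha_{ij}}}}{e^{\lambda^{l}_{\alpha_{ij}}} + e^{\lambda^{l}_{\beta_{ij}}} + e^{\lambda^{l}_{\gamma_{ij}}}},
\qquad
y^{l}_{ij} = \alpha^{l}_{ij}\, x^{1\to l}_{ij} + \beta^{l}_{ij}\, x^{2\to l}_{ij} + \gamma^{l}_{ij}\, x^{3\to l}_{ij}
$$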

I added fixes to implement ASFF and BiFPN (from EfficientDet): https://github.com/AlexeyAB/darknet/issues/3772#issuecomment-559592123


In shortly, ASFF mapping the inputsx0, x1, x2 of yolo0, yolo1, yolo2 to each other to enhance the detection, but I still wonder, which layers output respond to x0, x1, x2?

It seems to be layers 17, 24, 32:

https://github.com/ruinmessi/ASFF/blob/c74e08591b2756e5f773892628dd9a6d605f4b77/models/yolov3_asff.py#L142

https://github.com/ruinmessi/ASFF/blob/c74e08591b2756e5f773892628dd9a6d605f4b77/models/yolov3_asff.py#L129

isgursoy commented 4 years ago

waiting for improvements, good things happening here

Kyuuki93 commented 4 years ago

This formula seems to be softmax a = exp(x1) / (exp(x1) + exp(x2) + exp(x3)) https://en.wikipedia.org/wiki/Softmax_function

Yeah, I got it: his fusion is done with a 1x1 conv, a softmax and a weighted sum.

I added fixes to implement ASFF and BiFPN (from EfficientDet): #3772 (comment)

I will try to implement the ASFF and BiFPN modules and run some tests.

Kyuuki93 commented 4 years ago

For up-sampling, we first apply a 1x1 convolution layer to compress the number of channels of features to that in level l, and then upscale the resolutions respectively with interpolation.

@AlexeyAB How can this upscaling be implemented in a .cfg file?

AlexeyAB commented 4 years ago

@Kyuuki93 [upsample] layer with stride=2 or stride=4
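For reference, nearest-neighbour upsampling by an integer stride can be sketched as follows. To my understanding darknet's [upsample] layer does this in the forward pass (ignoring its optional scale factor), while the ASFF paper only says "interpolation"; `upsample_nearest` is a hypothetical, illustrative helper.

```c
#include <assert.h>

/* Nearest-neighbour upsampling of a c x h x w map (CHW) by an integer
 * stride: every input pixel is copied into a stride x stride block. */
void upsample_nearest(const float *in, float *out, int c, int h, int w, int stride)
{
    assert(c > 0 && h > 0 && w > 0 && stride >= 1);
    int oh = h * stride, ow = w * stride;
    for (int k = 0; k < c; ++k)
        for (int y = 0; y < oh; ++y)
            for (int x = 0; x < ow; ++x)
                out[(k * oh + y) * ow + x] = in[(k * h + y / stride) * w + x / stride];
}
```

stride=2 turns 13x13 into 26x26 and stride=4 turns 13x13 into 52x52, matching the 2x and 4x upsample layers in the log.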

Kyuuki93 commented 4 years ago
   layer   filters  size/strd(dil)      input                output
   0 conv     32       3 x 3/ 1    416 x 416 x   3 ->  416 x 416 x  32 0.299 BF
   1 conv     64       3 x 3/ 2    416 x 416 x  32 ->  208 x 208 x  64 1.595 BF
   2 conv     32       1 x 1/ 1    208 x 208 x  64 ->  208 x 208 x  32 0.177 BF
   3 conv     64       3 x 3/ 1    208 x 208 x  32 ->  208 x 208 x  64 1.595 BF
   4 Shortcut Layer: 1
   5 conv    128       3 x 3/ 2    208 x 208 x  64 ->  104 x 104 x 128 1.595 BF
   6 conv     64       1 x 1/ 1    104 x 104 x 128 ->  104 x 104 x  64 0.177 BF
   7 conv    128       3 x 3/ 1    104 x 104 x  64 ->  104 x 104 x 128 1.595 BF
   8 Shortcut Layer: 5
   9 conv     64       1 x 1/ 1    104 x 104 x 128 ->  104 x 104 x  64 0.177 BF
  10 conv    128       3 x 3/ 1    104 x 104 x  64 ->  104 x 104 x 128 1.595 BF
  11 Shortcut Layer: 8
  12 conv    256       3 x 3/ 2    104 x 104 x 128 ->   52 x  52 x 256 1.595 BF
  13 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  14 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  15 Shortcut Layer: 12
  16 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  17 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  18 Shortcut Layer: 15
  19 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  20 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  21 Shortcut Layer: 18
  22 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  23 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  24 Shortcut Layer: 21
  25 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  26 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  27 Shortcut Layer: 24
  28 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  29 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  30 Shortcut Layer: 27
  31 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  32 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  33 Shortcut Layer: 30
  34 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  35 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  36 Shortcut Layer: 33
  37 conv    512       3 x 3/ 2     52 x  52 x 256 ->   26 x  26 x 512 1.595 BF
  38 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  39 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  40 Shortcut Layer: 37
  41 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  42 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  43 Shortcut Layer: 40
  44 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  45 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  46 Shortcut Layer: 43
  47 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  48 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  49 Shortcut Layer: 46
  50 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  51 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  52 Shortcut Layer: 49
  53 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  54 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  55 Shortcut Layer: 52
  56 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  57 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  58 Shortcut Layer: 55
  59 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  60 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  61 Shortcut Layer: 58
  62 conv   1024       3 x 3/ 2     26 x  26 x 512 ->   13 x  13 x1024 1.595 BF
  63 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  64 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  65 Shortcut Layer: 62
  66 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  67 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  68 Shortcut Layer: 65
  69 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  70 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  71 Shortcut Layer: 68
  72 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  73 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  74 Shortcut Layer: 71
  75 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  76 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  77 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  78 max                5x 5/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.002 BF
  79 route  77                                 ->   13 x  13 x 512 
  80 max                9x 9/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.007 BF
  81 route  77                                 ->   13 x  13 x 512 
  82 max               13x13/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.015 BF
  83 route  82 80 78 77                        ->   13 x  13 x2048 
# END SPP #
  84 conv    512       1 x 1/ 1     13 x  13 x2048 ->   13 x  13 x 512 0.354 BF
  85 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  86 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
# A(/32 Feature Map) #
  87 conv    256       1 x 1/ 1     13 x  13 x 512 ->   13 x  13 x 256 0.044 BF
  88 upsample                 2x    13 x  13 x 256 ->   26 x  26 x 256
# A -> B # 
  89 route  86                                 ->   13 x  13 x 512 
  90 conv    128       1 x 1/ 1     13 x  13 x 512 ->   13 x  13 x 128 0.022 BF
  91 upsample                 4x    13 x  13 x 128 ->   52 x  52 x 128
# A -> C #
  92 route  86                                 ->   13 x  13 x512
  93 conv    256       1 x 1/ 1     13 x  13 x512 ->   13 x  13 x 256 0.044 BF
  94 upsample                 2x    13 x  13 x 256 ->   26 x  26 x 256
  95 route  94 61                              ->   26 x  26 x 768 
  96 conv    256       1 x 1/ 1     26 x  26 x 768 ->   26 x  26 x 256 0.266 BF
  97 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  98 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  99 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
 100 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
# B(/16 Feature Map) #
 101 conv    512       3 x 3/ 2     26 x  26 x 256 ->   13 x  13 x 512 0.399 BF
# B -> A #
 102 route  100                                    ->   26 x  26 x 256 
 103 conv    128       1 x 1/ 1     26 x  26 x 256 ->   26 x  26 x 128 0.044 BF
 104 upsample                 2x    26 x  26 x 128 ->   52 x  52 x 128
# B -> C #
 105 route  100                                    ->   26 x  26 x 256 
 106 conv    128       1 x 1/ 1     26 x  26 x 256 ->   26 x  26 x 128 0.044 BF
 107 upsample                 2x    26 x  26 x 128 ->   52 x  52 x 128
 108 route  107 36                             ->   52 x  52 x 384 
 109 conv    128       1 x 1/ 1     52 x  52 x 384 ->   52 x  52 x 128 0.266 BF
 110 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
 111 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
 112 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
 113 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
# C(/8 Feature Map) #
 114 max                2x 2/ 2     52 x  52 x 128 ->   26 x  26 x 128 0.000 BF
 115 conv    512       3 x 3/ 2     26 x  26 x 128 ->   13 x  13 x 512 0.199 BF
# C -> A #
 116 route  113                                    ->   52 x  52 x 128 
 117 conv    256       3 x 3/ 2     52 x  52 x 128 ->   26 x  26 x 256 0.399 BF
# C -> B #
 118 route  86 101 115                             ->   13 x  13 x1536 
 119 conv      3       1 x 1/ 1     13 x  13 x1536 ->   13 x  13 x   3 0.002 BF
 120 route  119                                0/3 ->   13 x  13 x   1 
 121 scale Layer: 86
darknet: ./src/scale_channels_layer.c:23: make_scale_channels_layer: Assertion `l.out_c == l.c' failed.
Aborted (core dumped)

@AlexeyAB I created an asff.cfg based on yolov3-spp.cfg, and there is an error: layer 86 is 13x13x512, while alpha (layer 120, the first group of layer 119) is 13x13x1. Do the outputs of these layers have to be the same size for [scale_channels]?

AlexeyAB commented 4 years ago

@Kyuuki93 It seems I fixed it: https://github.com/AlexeyAB/darknet/commit/5ddf9c74a58ce61d2aa82b806b8d0912ab6cf8f3#diff-35a105a0ce468de87dbd554c901a45eeR23
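In the log above, layer 121 scales layer 86 (13x13x512) by the 13x13x1 alpha map. With scale_wh=1, as in the `[scale_channels] from=22 scale_wh=1` lines of the example cfg, the single-channel map is meant to be broadcast across all channels. A CPU sketch of that forward pass, written from this description rather than from darknet's source (`scale_channels_wh` is a hypothetical name):

```c
#include <assert.h>

/* scale_channels with scale_wh=1, CHW layout, one image:
 * out[k][p] = from[k][p] * scale[p] for every channel k and pixel p,
 * where scale is a single-channel h*w map (e.g. alpha). */
void scale_channels_wh(const float *from, const float *scale, float *out,
                       int c, int hw)
{
    assert(c > 0 && hw > 0);
    for (int k = 0; k < c; ++k)
        for (int p = 0; p < hw; ++p)
            out[k * hw + p] = from[k * hw + p] * scale[p];
}
```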

Kyuuki93 commented 4 years ago

[route]
layers=22,33,44 # 3-layers which are already resized to the same WxHxC

[convolutional]
stride=1
size=1
filters=3
activation=normalize_channels # ReLU is integrated to activation=normalize_channels

[route]
layers=-1
group_id=0
groups=3

[scale_channels]
from=22
scale_wh=1

[route]
layers=-3
group_id=1
groups=3

[scale_channels]
from=33
scale_wh=1

[route]
layers=-5
group_id=2
groups=3

[scale_channels]
from=44
scale_wh=1

[shortcut]
from=-3
activation=linear

[shortcut]
from=-6
activation=linear

@AlexeyAB In your ASFF-like module, what exactly does activation = normalize_channels do?

If activation = normalize_channels uses ReLU to calculate gradients, I think it should be activation = linear followed by a separate softmax over (x1, x2, x3), to match this formula: alpha = exp(x1) / (exp(x1) + exp(x2) + exp(x3)). Or should it be activation = softmax, for SoftmaxBackward?

https://github.com/ruinmessi/ASFF/blob/f7814211b1fd1e6cde5e144503796f4676933667/models/network_blocks.py#L242

levels_weight = F.softmax(levels_weight, dim=1), where levels_weight.shape is torch.Size([1,3,13,13])

Is 'activation = normalize_channels' the same as this F.softmax?

If activation = normalize_channels actually executes this code (normalize_channels with a ReLU function, so negative values are removed), https://github.com/AlexeyAB/darknet/blob/9bb3c53698963f2a495be2dd9877d6ff523fe2ad/src/activations.c#L151-L177

that might explain this result:

| Model | Chart | cfg |
|---|---|---|
| spp,mse | yolov3-spp-chart | yolov3-spp.cfg.txt |
| spp,mse,asff | chart | yolov3-spp-asff.cfg.txt |

I think the normalization with the constraint channels_sum() = 1 is crucial, since it indicates which ASFF feature an object belongs to.

Also, this ASFF module is a little different from your example: instead of

[route]
layers = 22,33,44# 3-layers which are already resized to the same WxHxC

...

use

[route]
layers = 22

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = 33

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = 44

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = -1,-3,-5

[convolutional]
stride=1
size=1
filters=3
activation= normalize_channels

...
AlexeyAB commented 4 years ago

@Kyuuki93

I think the normalization with constraints channels_sum() = 1 was crucial, which indicate objects belongs to which ASFF feature.

What do you mean?

And this ASFF module have a little different with your example, instead of

Why?


In your ASFF-like module, what exactly activation = normalize_channels do?

If activation = normalize_channels use relu to calculate gradients, I think it should be activation = linear and use another softmax for (x1, x2, x3), to mach this formula alpha = exp(x1) / (exp(x1) + exp(x2) + exp(x3)), Or activation = softmax for SoftmaxBackward?

normalize_channels implements Fast Normalized Fusion, which should give the same accuracy as a SoftMax across channels but run faster; it is used in BiFPN for EfficientDet: https://github.com/AlexeyAB/darknet/issues/4346
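The Fast Normalized Fusion formula from the EfficientDet paper, applied per pixel across the level-weight channels, can be sketched as below. The placement of eps follows the paper's formula and is an assumption about darknet's normalize_channels, which may differ in detail; `fast_normalized_fusion` is a hypothetical name.

```c
#include <assert.h>

/* Fast normalized fusion across n channels at one pixel:
 *   w_i = relu(x_i) / (eps + sum_j relu(x_j))
 * Negative inputs give exactly 0 (and therefore a zero ReLU gradient);
 * with eps > 0 the weights sum to slightly less than 1. */
void fast_normalized_fusion(const float *x, float *w, int n, float eps)
{
    assert(n > 0 && eps >= 0.0f);
    float sum = eps;
    for (int i = 0; i < n; ++i)
        if (x[i] > 0.0f) sum += x[i];
    for (int i = 0; i < n; ++i)
        w[i] = (x[i] > 0.0f) ? x[i] / sum : 0.0f;   /* no division when x_i <= 0 */
}
```

Compared with a softmax there is no exp(), which is where the speed-up comes from; the trade-off is that any channel with a non-positive input gets weight 0.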

Later I will add activation=normalize_channels_softmax


Kyuuki93 commented 4 years ago

I think the normalization with constraints channels_sum() = 1 was crucial, which indicate objects belongs to which ASFF feature.

What do you mean?

Sorry, let me clarify:

alpha(i,j) + beta(i,j) + gamma(i,j) = 1, with alpha(i,j) > 0, beta(i,j) > 0, gamma(i,j) > 0

In normalize_channels, probably because of this code:

if (val > 0) val = val / sum; 
else val = 0;

many alpha, beta or gamma values are set to 0, so their ReLU gradients are 0 as well; the gradients vanish at the very beginning and training doesn't work properly. E.g. after 25k iterations the best mAP@0.5 was just 10.41%, and during training the Obj: value was very hard to increase.

And this ASFF module have a little different with your example, instead of

Why?

I checked the author's model: layers 22, 33, 44 are never concatenated directly; I just implemented his network structure. In his model the coefficients are calculated from layers 22, 33, 44 separately, and the channels change like this:

512 -> 8
512 -> 8  (cat to) 24 -> 3 
512 -> 8

instead of 512 -> 3
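The channel flow above (512 -> 8 for each level, concatenated to 24, then 24 -> 3) amounts, at every pixel, to three 1x1 convs, a concatenation and one more 1x1 conv. A minimal per-pixel sketch, with the batch normalization and leaky activation of the compress convs omitted and the weights passed in explicitly; the function names and the MAX_MID limit are hypothetical.

```c
#include <assert.h>

#define MAX_MID 64   /* upper bound on compressed channels per level */

/* 1x1 convolution at one pixel: out[o] = b[o] + sum_i W[o*cin + i] * in[i] */
static void conv1x1_pixel(const float *in, int cin, const float *W,
                          const float *b, int cout, float *out)
{
    for (int o = 0; o < cout; ++o) {
        float s = b[o];
        for (int i = 0; i < cin; ++i) s += W[o * cin + i] * in[i];
        out[o] = s;
    }
}

/* Level-weight logits at one pixel: each resized level (cin channels, 512
 * in the cfg) is compressed to `mid` channels (8 in the cfg), the three
 * results are concatenated (3*mid = 24) and mapped to 3 logits. */
void asff_weight_logits(const float *lvl[3], int cin, int mid,
                        const float *Wc[3], const float *bc[3],
                        const float *Wf, const float *bf, float logits[3])
{
    assert(cin > 0 && mid > 0 && mid <= MAX_MID);
    float cat[3 * MAX_MID];
    for (int l = 0; l < 3; ++l)
        conv1x1_pixel(lvl[l], cin, Wc[l], bc[l], mid, cat + l * mid);
    conv1x1_pixel(cat, 3 * mid, Wf, bf, 3, logits);
}
```

The three logits are what the softmax (or normalize_channels) then turns into alpha, beta, gamma.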

There is in the normalize_channels implemented Fast Normalized Fusion that should have the same Accuracy but faster Speed than SoftMax across channels, that is used in BiFPN for EfficientDet: #4346

I will try to find out why BiFPN works with ReLU-style normalize_channels but ASFF does not. I have an idea; let me check it.

Later I will add activation=normalize_channels_softmax

I will run another test then.

AlexeyAB commented 4 years ago

@Kyuuki93

I checked author's model, layers 22,33,44 were never concat, I just implemented his network structure.

You have done right. I have not yet verified the entire cfg file as a whole.

Here we are not talking about layers with indices exactly 22, 33, 44. This is just an example. It means that some layers with indices XX, YY, ZZ have already been resized to the same WxHx8. It is assumed that the following layers have already been applied: conv_stride_2, maxpool_stride_2, upsample_stride_2 and 4, and then a conv layer with filters=8. These 3 layers of size WxHx8 are then concatenated: https://github.com/ruinmessi/ASFF/blob/master/models/network_blocks.py#L240

That's how you did it.


In normalize_channels, maybe result from this code:

if (val > 0) val = val / sum; else val = 0; many alpha or beta, gamma were set to 0, so relu gradients was 0 too, so gradients were vanished at very beginning, and in this way, training doesn't work properly, e.g. after 25k iters, best mAP@0.5 just 10.41%, in the training, the value of Obj: were very hard to increase.

Yes, for one image some outputs (alpha, beta or gamma) will be zero, and for another image other outputs will be zero. There will not be dead neurons in Yolo, since all other layers use leaky-ReLU rather than ReLU.

This is a common problem with ReLU, called dead neurons: https://datascience.stackexchange.com/questions/5706/what-is-the-dying-relu-problem-in-neural-networks It applies to all modern neural networks that use ReLU: MobileNet v1, ResNet-101, ... Leaky-ReLU, Swish or Mish solve this problem.

The dead-neuron problem occurs only if at least 2 conv layers with ReLU go one after another. Then the output of conv-1 is always >=0, so both the input and the output of conv-2 are always >=0. In this case, since the input of conv-2 is always >=0, if weights[i] < 0 then the output of ReLU will always be 0 and the gradient will always be 0, so there will be dead neurons: this weights[i] < 0 will never be changed.

But if the conv-1 layer has a leaky-ReLU (as in Yolo), Swish or Mish activation, then the input of conv-2 can be >0 or <0, so regardless of weights[i] (if weights[i] != 0) the gradient will not always be 0, and this weights[i] < 0 will be changed at some point.
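The argument can be checked with a toy single-weight model. The helper names are hypothetical, and, like the argument above, the model ignores the other weights and the bias of conv-2.

```c
#include <assert.h>

/* Activations of conv-1 in the two scenarios. */
float relu_act(float v)  { return v > 0.0f ? v : 0.0f; }
float leaky_act(float v) { return v > 0.0f ? v : 0.1f * v; }  /* leaky slope 0.1 */

/* One weight w of conv-2 followed by ReLU: y = relu(w * x), so
 * dy/dw = x when w * x > 0 and 0 otherwise. Returns how many of the n
 * conv-1 outputs x[i] give this weight a non-zero gradient. */
int count_nonzero_grads(float w, const float *x, int n)
{
    assert(n >= 0);
    int cnt = 0;
    for (int i = 0; i < n; ++i)
        if (w * x[i] > 0.0f) ++cnt;
    return cnt;
}
```

With a negative w, ReLU outputs from conv-1 never produce a gradient, while leaky-ReLU outputs still do for the negative pre-activations.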

AlexeyAB commented 4 years ago

@Kyuuki93

Also you can try to use

[convolutional]
stride=1
size=1
filters=3
activation=logistic

instead of

[convolutional]
stride=1
size=1
filters=3
activation=normalize_channels
AlexeyAB commented 4 years ago

@Kyuuki93

I added [convolutional] activation=normalize_channels_softmax. Check whether there are bugs: https://github.com/AlexeyAB/darknet/commit/c9c745ccf1de97d01cc3c69f81e83011f6439f1a and https://github.com/AlexeyAB/darknet/commit/4f52ba1a25ade35119cefc3840ef65a509851809


Page 4: https://arxiv.org/pdf/1911.09516v2.pdf


Kyuuki93 commented 4 years ago

Here we are not talking about layers with indices exactly 22, 33, 44. This is just an example.

Yes, I am aware: layer 22 there is actually layer 86 in darknet's yolov3-spp.cfg, and so on.

There will be dead neurons problem only if at least 2 conv-layers with ReLU in a row, go one after another. So output of conv-1 will be always >=0, so both input and output of conv-2 will be always >=0 In this case, since input of conv-2 is alwyas >=0, then if weights[i] < 0 then output of ReLU will be always 0 and Gradient will be always 0 - so there will be dead neurons, this weights[i]<0 will never be changed.

But if conv-1 layer has leak-ReLU (as in Yolo) or Swish or Mish activation, then input of conv-2 can be >0 or <0, then regardless of weights[i] (if weights[i] != 0) the Gradient will not be always == 0, and this weights[i]<0 will be changed sometime.

I see, so there is a little influence but it should work. I will try activation=logistic and activation=normalize_channels_softmax and update the results later.

Kyuuki93 commented 4 years ago

@AlexeyAB I created a new asff.cfg, yolov3-spp-asff-giou-logistic.cfg.txt. With normalize_channels_softmax the training loss goes to NaN in ~100 iterations; with logistic I got this result (yolov3-spp-asff-giou-logistic-chart), whereas yolov3-spp.cfg with mse loss achieved 89% mAP@0.5.

If you have time, could you tell me what mistake I made there?

Kyuuki93 commented 4 years ago

If you got a time, could you tell me what mistake I made in there?

@AlexeyAB Sorry, that's my fault: the previous .cfg file connected the wrong layers. This one should be right for the ASFF module: yolov3-spp-asff.cfg.txt, visualized with Netron: yolov3-spp-asff.png.zip

Unfortunately, this net is untrainable. The official repo mentions that the ASFF module needs a long warm-up to avoid NaN loss, but in darknet the NaN loss shows up as soon as lr > 0, no matter whether activation = logistic, norm_channels or norm_channels_softmax, so I am wondering which part goes wrong.

Following ASFF's basic idea, i.e. every yolo layer uses the feature maps of all scales, I created a simplified ASFF module: it just adds the feature maps from layers 22, 33, 44 (with [shortcut]) instead of multiplying them by alpha, beta, gamma.

This is the simplified one: yolov3-spp-asff-simplified.cfg.txt, visualized with Netron: yolov3-spp-asff-simplified.png.zip

| Model | Chart |
|---|---|
| yolov3-spp + Gaussian_yolo(iou_n,uc_n = 0.5) + iou_thresh=0.213 | spp,giou,gs,it |
| yolov3-spp + Gaussian_yolo(iou_n,uc_n = 0.5) + iou_thresh=0.213 + asff-sim | chart |

The complete training results will be updated several hours later, but it seems the simplified ASFF module could boost AP, or at least speed up training.

As for the ASFF module, if the cfg file is not wrong, maybe some of these layers, e.g. scale_channels or activation = norm_channels_*, don't work as I expect?

AlexeyAB commented 4 years ago

This one should be right for ASFF module, yolov3-spp-asff.cfg.txt,

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

like

[yolo]
mask = 0,1,2
#anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
anchors =  57, 64,  87,113, 146,110, 116,181, 184,157, 175,230, 270,196, 236,282, 322,319
classes=1
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
Kyuuki93 commented 4 years ago

This one should be right for ASFF module, yolov3-spp-asff.cfg.txt,

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

like

[yolo]
mask = 0,1,2
#anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
anchors =  57, 64,  87,113, 146,110, 116,181, 184,157, 175,230, 270,196, 236,282, 322,319
classes=1
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

I will try. Meanwhile, here are the asff-sim results with gs, giou, iou_thresh:

| Model | AP@.5 | AP@.75 |
|---|---|---|
| baseline | 91.89% | 63.53% |
| +asffsim | 91.62% | 63.28% |

Results with mse loss will be reported tomorrow.

AlexeyAB commented 4 years ago

@Kyuuki93 So asff-simplified doesn't improve accuracy.

Try with default [yolo]+mse without normalizers and if it doesn't work then try with default anchors.

Kyuuki93 commented 4 years ago

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

Yes, ASFF-SIM with default [yolo] decreases both AP@.5 and AP@.75 by 0.48%:

| Model | AP@.5 | AP@.75 |
|---|---|---|
| baseline | 89.52% | 51.72% |
| +asffsim | 89.04% | 51.24% |

AlexeyAB commented 4 years ago

@Kyuuki93

Try norm_channels or norm_channels_softmax with default [yolo] layers. Maybe only [Gaussian_yolo] produces NaN with ASFF.

Kyuuki93 commented 4 years ago

I tried; it's the same.

AlexeyAB commented 4 years ago

@Kyuuki93

And about ASFF module, if the cfg file was no wrong, maybe these layers e.g. scale_channels, activation = normchannels* not work as my expect?

I checked the implementation of activation = norm_channels_* and didn't find bugs.

Also, the cfg file https://github.com/AlexeyAB/darknet/files/3939064/yolov3-spp-asff.cfg.txt is almost correct, except for the very low learning_rate and burn_in=0: you should use burn_in=4000 and a higher learning_rate. Also use default [yolo] without normalizers.


  1. Do you get Nan or low mAP with activation=norm_channels ?

  2. set [net] burn_in=4000 in cfg-file

  3. Change this line: https://github.com/AlexeyAB/darknet/blob/dbe34d78658746fcfc9548ebab759895ea05a70c/src/blas_kernels.cu#L1153 to this: atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index] / channel_size);

  4. Check that grad=0 there https://github.com/AlexeyAB/darknet/blob/dbe34d78658746fcfc9548ebab759895ea05a70c/src/activation_kernels.cu#L513

  5. use default [yolo] and iou_normalizer = 1.0 iou_loss = mse

  6. Recompile and train ASFF with default [yolo] and activation=norm_channels and activation=norm_channels_softmax, with

    learning_rate=0.001
    burn_in=4000

Show output chart.png with Loss and mAP for both activation=norm_channels and activation=norm_channels_softmax
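For intuition about step 3, here is a CPU sketch of the scale_channels (scale_wh=1) backward pass written from the chain rule. It is not darknet's CUDA kernel; `scale_channels_wh_backward` and its `norm` argument are hypothetical, with norm standing in for the channel_size divisor from step 3.

```c
#include <assert.h>

/* Backward pass of out[k*hw + p] = from[k*hw + p] * scale[p] for one image;
 * gradients are accumulated into d_from and d_scale:
 *   d_from[k*hw + p] += scale[p] * delta[k*hw + p]
 *   d_scale[p]       += sum_k delta[k*hw + p] * from[k*hw + p] / norm
 * norm = 1 is the plain chain rule; passing channel_size as norm mimics
 * the extra division added in step 3. */
void scale_channels_wh_backward(const float *delta, const float *from,
                                const float *scale, float *d_from,
                                float *d_scale, int c, int hw, float norm)
{
    assert(c > 0 && hw > 0 && norm > 0.0f);
    for (int k = 0; k < c; ++k)
        for (int p = 0; p < hw; ++p) {
            int i = k * hw + p;
            d_from[i] += scale[p] * delta[i];
            d_scale[p] += delta[i] * from[i] / norm;
        }
}
```

Because d_scale sums contributions from every channel (512 of them for layer 86), its magnitude grows with the channel count, which is consistent with a divisor keeping the scale gradient from blowing up.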

Kyuuki93 commented 4 years ago

Very low LR and burn_in=0 were set to test whether NaN shows up when LR > 0. I have tried the usual LR and burn_in=2000; it's the same, no matter whether yolo or gs_yolo with any loss.

I will try your suggestion tomorrow, it’s midnight here

Kyuuki93 commented 4 years ago

For 4., grad is already 0. I set the baseline as:

learning_rate = 0.0005 # for 2 GPUs 
burn_in = 4000
...
activation = normalize_channels_softmax # in the end of ASFF
...
[yolo]
iou_loss = mse
iou_normalizer = 1.0
cls_normalizer = 1.0
ignore_thresh = 0.213
...

Also, the pre-trained weights used are yolov3-spp.conv.86 instead of yolov3-spp.conv.88.

| Settings | Got NaN | Iteration of NaN | Chart |
|---|---|---|---|
| baseline | y | 363 | - |
| activation -> normalize_channels | n | - | chart |

After adding /channel_size to darknet/src/blas_kernels.cu#L1153:

| Settings | Got NaN | Iteration of NaN | Chart |
|---|---|---|---|
| baseline | n | - | chart |
| activation -> normalize_channels | | | |

@AlexeyAB It seems to work fine for now. Full results will be updated later, delayed by a few days because of a business trip.

AlexeyAB commented 4 years ago

@Kyuuki93 Fine. What is the baseline in your table?

Kyuuki93 commented 4 years ago

@Kyuuki93 Fine. What is base line in your table?

yolov3-spp.cfg with mse loss, only adding iou_thresh = 0.213

AlexeyAB commented 4 years ago

@Kyuuki93

This is strange:

yolov3-spp.cfg with mse loss, only add iou_thresh = 0.213

But why does the default yolov3-spp.cfg go to NaN, when it has no ASFF, no [scale_channels] layer and no activation=normalize_channels? https://github.com/AlexeyAB/darknet/issues/4382#issuecomment-564437176

And why doesn't it go to NaN after fixing the [scale_channels] layer, when yolov3-spp.cfg doesn't have a [scale_channels] layer?

Kyuuki93 commented 4 years ago

This is strange:

yolov3-spp.cfg with mse loss, only add iou_thresh = 0.213

But why default yolov3-spp.cfg goes to Nan, while there are no ASFF, [scale_channels]-layer or activation=normalize_channels ? #4382 (comment)

And why it doesn't go to Nan after fixing [scale_channels]-layer , while yolov3-spp.cfg doesn't have [scale_channels]-layer?

Sorry, all the tests use the ASFF module, so the baseline is yolov3-spp.cfg with mse loss, it=0.213, asff.

Kyuuki93 commented 4 years ago

Here is the baseline .cfg file: yolov3-spp-asff.cfg.txt

AlexeyAB commented 4 years ago

Thanks for explanation. So baseline is = yolov3-spp + iou_loss = mse + iou_thresh = 0.213 + ASFF with activation=normalize_channels_softmax

AlexeyAB commented 4 years ago
Kyuuki93 commented 4 years ago

How about guided anchoring? In the ASFF paper, the YOLO head consists of GA+RFB+Deform-conv.

AlexeyAB commented 4 years ago

@Kyuuki93 I haven't looked yet. Let's test these blocks first, to make sure they work and increase accuracy.

(A lot of features increase accuracy only in rare cases or if tricks and cheats are used.)

Kyuuki93 commented 4 years ago

@Kyuuki93 Didn't look yet. Lets test these blocks first, to make sure they work and increase accuracy.

(A lot of features increase accuracy only in rare cases or if tricks and cheats are used.)

Of course, I will test DropBlock these days or after I'm back at the office. By the way, training on my 4-GPU machine works well for over 80k iterations, but on the 2-GPU machine it still gets killed after 10k iterations, so the updated chart will be split.

AlexeyAB commented 4 years ago

@Kyuuki93 What GPU, OS and OpenCV versions do you use?


How to use RFB-block: https://github.com/AlexeyAB/darknet/issues/4507

Kyuuki93 commented 4 years ago

@Kyuuki93 What GPU, OS and OpenCV versions do you use for?

  • 4xGPUs: 4 2080 Ti GPUs, Ubuntu 18.04, OpenCV 3.4.7
  • 2xGPUs: 2 1080 Ti GPUs, Ubuntu 18.04, OpenCV 3.4.7

Kyuuki93 commented 4 years ago
| Model | Chart | AP@.5 | AP@.75 | Inference Time (416x416) |
|---|---|---|---|---|
| spp,mse,it=0.213 | spp,mse,it | 92.01% | 60.49% | 13.75ms |
| spp,mse,it=0.213,asff(softmax) | spp,mse,it,asff | 92.45% | 61.83% | 15.44ms |
| spp,mse,it=0.213,asff(relu) | | 91.83% | 59.60% | 15.34ms |
| spp,mse,it=0.213,asff(logistic) | | 91.18% | 60.79% | 15.40ms |

I will complete this table soon. So far the ASFF module seems to work well: its AP is already higher than spp,giou,gs,it, which can be seen in https://github.com/AlexeyAB/darknet/issues/3874#issuecomment-561064425 @AlexeyAB

Will fill in the rest later.

AlexeyAB commented 4 years ago

I added this fix: https://github.com/AlexeyAB/darknet/commit/d137d304c1410253894dbfb7abaadfe6f4f867e7

Compare which of the ASFFs is better: logistic, normalize_channels or normalize_channels_softmax

Now that we know it works well, you can also try to test it with [Gaussian_yolo]

Kyuuki93 commented 4 years ago

I added this fix: d137d30

Compare which of the ASFFs is better: logistic, normalize_channels or normalize_channels_softmax

Now when we know that it works well, you also can try to test it with [Guassian_yolo]

I filled in the previous results table: the ASFF module gives +0.44% AP@.5 and +1.34% AP@.75. I have added this to https://github.com/AlexeyAB/darknet/issues/3874#issuecomment-561064425

AlexeyAB commented 4 years ago

@Kyuuki93 Nice! Is it normalize_channels_softmax-ASFF?

What improvement in accuracy does normalize_channels(avg_ReLU)-ASFF give?

Kyuuki93 commented 4 years ago

@Kyuuki93 Nice! Is it normalize_channels_softmax-ASFF?

What improvement in accuracy does give the normalize_channels(avg_ReLU)-ASFF?

Yes, it is normalize_channels_softmax-ASFF. I haven't tested the other normalize_channels variant yet; I will add DropBlock before it.

AlexeyAB commented 4 years ago

@Kyuuki93

It is weird that they use mlist.append(DropBlock(block_size=1, keep_prob=1)), https://github.com/ruinmessi/ASFF/blob/master/models/yolov3_asff.py#L39-L56


So I recommend using

Maybe


I think I should split dropblock_size= into 2 params:

Kyuuki93 commented 4 years ago

This is weird that they use mlist.append(DropBlock(block_size=1, keep_prob=1)), https://github.com/ruinmessi/ASFF/blob/master/models/yolov3_asff.py#L39-L56

So I recommend to use

  • either DropOut + ASFF ([dropout] probability=0.1)
  • or DropBlock + ASFF + RFB-block
[dropout]
dropblock=1
dropblock_size=0.1  # 10% of width and height
probability=0.1     # this is drop probability = (1 - keep_probability)

OK, I will check that, but it's not so convenient to add DropBlock while I'm out of the office, so I will test normalize_channels_* first; some results will come out tomorrow.

And check here: https://github.com/ruinmessi/ASFF/blob/538459e8a948c9cd70dbd8b66ee6017d20af77cc/main.py#L317-L337. block_size=1, keep_prob=1 should be just for the baseline, i.e. the original YOLO head.

AlexeyAB commented 4 years ago

@Kyuuki93 Thanks, I missed this code.

I fixed DropBlock: https://github.com/AlexeyAB/darknet/issues/4498 Just use the following (dropblock_size_abs=7 if RFB is used, otherwise dropblock_size_abs=2):

[dropout]
dropblock=1
dropblock_size_abs=7  # block size 7x7
probability=0.1       # this is drop probability = (1 - keep_probability)

It will work mostly the same as in the ASFF implementation (gradually increasing the block size from 1x1 to 7x7 during the first half of training).

Also, the drop probability will increase from 0.0 to 0.1 during the first half of training, as in the original DropBlock paper.
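The schedule described above can be sketched as follows. The linear ramp and the rounding of the block size are assumptions (the comment only says "gradually increasing"), darknet's actual code may interpolate differently, and `dropblock_schedule` is a hypothetical helper.

```c
#include <assert.h>

typedef struct { int block_size; float drop_prob; } dropblock_params;

/* Hypothetical linear ramp: block size grows from 1x1 to max_size x max_size
 * and drop probability from 0 to max_prob over the first half of training
 * (max_batches / 2 iterations), then both stay constant. */
dropblock_params dropblock_schedule(int iter, int max_batches, int max_size, float max_prob)
{
    assert(max_batches > 1 && max_size >= 1);
    int half = max_batches / 2;
    float t = iter >= half ? 1.0f : (float)iter / (float)half;
    if (t < 0.0f) t = 0.0f;
    dropblock_params p;
    p.block_size = 1 + (int)(t * (float)(max_size - 1) + 0.5f);  /* round to nearest */
    p.drop_prob = t * max_prob;
    return p;
}
```

With dropblock_size_abs=7 and probability=0.1, this gives 1x1 and 0.0 at the start, 4x4 and about 0.05 at a quarter of training, and 7x7 and 0.1 from the midpoint on.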

Kyuuki93 commented 4 years ago

@AlexeyAB I have filled in the table.