ceccocats / tkDNN

Deep neural network library and toolkit for high-performance inference on NVIDIA Jetson platforms
GNU General Public License v2.0

Yolov4tiny - 3 layer incompatibility #81

Closed · marvision-ai closed this 4 years ago

marvision-ai commented 4 years ago

I have been testing many versions of yolov4-tiny. Recently, alex released a yolov4-tiny-3l cfg. https://github.com/AlexeyAB/darknet/blob/de68e19cc627f642023f09513ac2306fbcbc1e4b/cfg/yolov4-tiny-3l.cfg

Note that I changed the width to 1120 and the height to 960 in the cfg.
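
For reference, those values go in the [net] section of the cfg; an excerpt with only the resolution changed from the stock yolov4-tiny-3l.cfg (the batch/subdivisions values match the log below) would look roughly like this:

[net]
batch=64
subdivisions=16
width=1120
height=960
channels=3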

I trained it to great accuracy and would like to use this model. When I attempt to export it, I get the following output:

nvidia@nvidia:~/ai/tkDNN/build$ ./test_yolo4tiny-3l-camshaft 
Not supported field: batch=64
Not supported field: subdivisions=16
Not supported field: momentum=0.9
Not supported field: decay=0.0005
Not supported field: angle=0
Not supported field: saturation = 1.5
Not supported field: exposure = 1.5
Not supported field: hue=.1
Not supported field: mosaic=1
Not supported field: learning_rate=0.00261
Not supported field: burn_in=1000
Not supported field: max_batches = 10000
Not supported field: policy=steps
Not supported field: steps=8000,9000
Not supported field: scales=.1,.1
New NETWORK (tkDNN v0.5, CUDNN v7.603)
!! FP16 INFERENCE ENABLED !!
Reading weights: I=3 O=32 KERNEL=3x3x1
Reading weights: I=32 O=64 KERNEL=3x3x1
Reading weights: I=64 O=64 KERNEL=3x3x1
Reading weights: I=32 O=32 KERNEL=3x3x1
Reading weights: I=32 O=32 KERNEL=3x3x1
Reading weights: I=64 O=64 KERNEL=1x1x1
Reading weights: I=128 O=128 KERNEL=3x3x1
Reading weights: I=64 O=64 KERNEL=3x3x1
Reading weights: I=64 O=64 KERNEL=3x3x1
Reading weights: I=128 O=128 KERNEL=1x1x1
Reading weights: I=256 O=256 KERNEL=3x3x1
Reading weights: I=128 O=128 KERNEL=3x3x1
Reading weights: I=128 O=128 KERNEL=3x3x1
Reading weights: I=256 O=256 KERNEL=1x1x1
Reading weights: I=512 O=512 KERNEL=3x3x1
Reading weights: I=512 O=256 KERNEL=1x1x1
Reading weights: I=256 O=512 KERNEL=3x3x1
Reading weights: I=512 O=27 KERNEL=1x1x1
Not supported field: anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
Not supported field: jitter=.3
Not supported field: cls_normalizer=1.0
Not supported field: iou_normalizer=0.07
Not supported field: iou_loss=ciou
Not supported field: ignore_thresh = .7
Not supported field: truth_thresh = 1
Not supported field: random=1
Not supported field: nms_kind=greedynms
Not supported field: beta_nms=0.6
Reading weights: I=256 O=128 KERNEL=1x1x1
Reading weights: I=384 O=256 KERNEL=3x3x1
Reading weights: I=256 O=27 KERNEL=1x1x1
Not supported field: anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
Not supported field: jitter=.3
Not supported field: cls_normalizer=1.0
Not supported field: iou_normalizer=0.07
Not supported field: iou_loss=ciou
Not supported field: ignore_thresh = .7
Not supported field: truth_thresh = 1
Not supported field: random=1
Not supported field: nms_kind=greedynms
Not supported field: beta_nms=0.6
Reading weights: I=256 O=64 KERNEL=1x1x1
Reading weights: I=192 O=128 KERNEL=3x3x1
Reading weights: I=128 O=27 KERNEL=1x1x1
Not supported field: anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
Not supported field: jitter=.3
Not supported field: cls_normalizer=1.0
Not supported field: iou_normalizer=0.07
Not supported field: iou_loss=ciou
Not supported field: ignore_thresh = .7
Not supported field: truth_thresh = 1
Not supported field: random=1
Not supported field: nms_kind=greedynms
Not supported field: beta_nms=0.6

====================== NETWORK MODEL ======================
N.  Layer type       input (H*W,CH)        output (H*W,CH) 
  0 Conv2d           960 x 1120,    3  ->  480 x  560,   32
  1 ActivationLeaky  480 x  560,   32  ->  480 x  560,   32
  2 Conv2d           480 x  560,   32  ->  240 x  280,   64
  3 ActivationLeaky  240 x  280,   64  ->  240 x  280,   64
  4 Conv2d           240 x  280,   64  ->  240 x  280,   64
  5 ActivationLeaky  240 x  280,   64  ->  240 x  280,   64
  6 Route            240 x  280,   32  ->  240 x  280,   32
  7 Conv2d           240 x  280,   32  ->  240 x  280,   32
  8 ActivationLeaky  240 x  280,   32  ->  240 x  280,   32
  9 Conv2d           240 x  280,   32  ->  240 x  280,   32
 10 ActivationLeaky  240 x  280,   32  ->  240 x  280,   32
 11 Route            240 x  280,   64  ->  240 x  280,   64
 12 Conv2d           240 x  280,   64  ->  240 x  280,   64
 13 ActivationLeaky  240 x  280,   64  ->  240 x  280,   64
 14 Route            240 x  280,  128  ->  240 x  280,  128
 15 Pooling          240 x  280,  128  ->  120 x  140,  128
 16 Conv2d           120 x  140,  128  ->  120 x  140,  128
 17 ActivationLeaky  120 x  140,  128  ->  120 x  140,  128
 18 Route            120 x  140,   64  ->  120 x  140,   64
 19 Conv2d           120 x  140,   64  ->  120 x  140,   64
 20 ActivationLeaky  120 x  140,   64  ->  120 x  140,   64
 21 Conv2d           120 x  140,   64  ->  120 x  140,   64
 22 ActivationLeaky  120 x  140,   64  ->  120 x  140,   64
 23 Route            120 x  140,  128  ->  120 x  140,  128
 24 Conv2d           120 x  140,  128  ->  120 x  140,  128
 25 ActivationLeaky  120 x  140,  128  ->  120 x  140,  128
 26 Route            120 x  140,  256  ->  120 x  140,  256
 27 Pooling          120 x  140,  256  ->   60 x   70,  256
 28 Conv2d            60 x   70,  256  ->   60 x   70,  256
 29 ActivationLeaky   60 x   70,  256  ->   60 x   70,  256
 30 Route             60 x   70,  128  ->   60 x   70,  128
 31 Conv2d            60 x   70,  128  ->   60 x   70,  128
 32 ActivationLeaky   60 x   70,  128  ->   60 x   70,  128
 33 Conv2d            60 x   70,  128  ->   60 x   70,  128
 34 ActivationLeaky   60 x   70,  128  ->   60 x   70,  128
 35 Route             60 x   70,  256  ->   60 x   70,  256
 36 Conv2d            60 x   70,  256  ->   60 x   70,  256
 37 ActivationLeaky   60 x   70,  256  ->   60 x   70,  256
 38 Route             60 x   70,  512  ->   60 x   70,  512
 39 Pooling           60 x   70,  512  ->   30 x   35,  512
 40 Conv2d            30 x   35,  512  ->   30 x   35,  512
 41 ActivationLeaky   30 x   35,  512  ->   30 x   35,  512
 42 Conv2d            30 x   35,  512  ->   30 x   35,  256
 43 ActivationLeaky   30 x   35,  256  ->   30 x   35,  256
 44 Conv2d            30 x   35,  256  ->   30 x   35,  512
 45 ActivationLeaky   30 x   35,  512  ->   30 x   35,  512
 46 Conv2d            30 x   35,  512  ->   30 x   35,   27
 47 Yolo              30 x   35,   27  ->   30 x   35,   27
 48 Route             30 x   35,  256  ->   30 x   35,  256
 49 Conv2d            30 x   35,  256  ->   30 x   35,  128
 50 ActivationLeaky   30 x   35,  128  ->   30 x   35,  128
 51 Upsample          30 x   35,  128  ->   60 x   70,  128
 52 Route             60 x   70,  384  ->   60 x   70,  384
 53 Conv2d            60 x   70,  384  ->   60 x   70,  256
 54 ActivationLeaky   60 x   70,  256  ->   60 x   70,  256
 55 Conv2d            60 x   70,  256  ->   60 x   70,   27
 56 Yolo              60 x   70,   27  ->   60 x   70,   27
 57 Route             60 x   70,  256  ->   60 x   70,  256
 58 Conv2d            60 x   70,  256  ->   60 x   70,   64
 59 ActivationLeaky   60 x   70,   64  ->   60 x   70,   64
 60 Upsample          60 x   70,   64  ->  120 x  140,   64
 61 Route            120 x  140,  192  ->  120 x  140,  192
 62 Conv2d           120 x  140,  192  ->  120 x  140,  128
 63 ActivationLeaky  120 x  140,  128  ->  120 x  140,  128
 64 Conv2d           120 x  140,  128  ->  120 x  140,   27
 65 Yolo             120 x  140,   27  ->  120 x  140,   27
===========================================================

GPU free memory: 18142.9 mb.
New NetworkRT (TensorRT v6.01)
Float16 support: 1
Int8 support: 1
DLAs: 2
Selected maxBatchSize: 1
GPU free memory: 17834.7 mb.
Building tensorRT cuda engine...
serialize net
create execution context
Input/outputs numbers: 4
input idex = 0 -> output index = 3
Data dim: 1 3 960 1120 1
Data dim: 1 27 120 140 1
RtBuffer 0   dim: Data dim: 1 3 960 1120 1
RtBuffer 1   dim: Data dim: 1 27 30 35 1
RtBuffer 2   dim: Data dim: 1 27 60 70 1
RtBuffer 3   dim: Data dim: 1 27 120 140 1
2 3
outputs size missmatch
/home/nvidia/ai/tkDNN/include/tkDNN/test.h:23
Aborting...

Any help would be greatly appreciated! Congrats on the great repo!

mive93 commented 4 years ago

Hi @marvision-ai

Can you show us the cpp of the test? The problem is probably just with the output bins used to check the correctness of the results.

This is the code for yolov4-tiny (with 2 yolo layers):

std::vector<std::string> output_bins = {
        bin_path + "/debug/layer30_out.bin",
        bin_path + "/debug/layer37_out.bin"
    };

For the model you are using, three are required.

marvision-ai commented 4 years ago

Hi @mive93 ,

Yes, you are correct! That makes sense. Here is the current one I am using:

#include <iostream>
#include <vector>
#include "tkdnn.h"
#include "test.h"
#include "DarknetParser.h"

int main() {
    std::string bin_path  = "/home/nvidia/ai/tkDNN/build/yolo4tiny-3l-shaft";
    std::vector<std::string> input_bins = { 
        bin_path + "/layers/input.bin"
    };
    std::vector<std::string> output_bins = {
        bin_path + "/debug/layer30_out.bin",
        bin_path + "/debug/layer37_out.bin"
    };
    std::string wgs_path  = bin_path + "/layers";
    std::string cfg_path  = std::string(TKDNN_PATH) + "/tests/darknet/cfg/yolov4-tiny-shaft-3l-rnd.cfg";
    std::string name_path = std::string(TKDNN_PATH) + "/tests/darknet/names/shaft.names";

    // parse darknet network
    tk::dnn::Network *net = tk::dnn::darknetParser(cfg_path, wgs_path, name_path);
    net->print();

    //convert network to tensorRT
    tk::dnn::NetworkRT *netRT = new tk::dnn::NetworkRT(net, net->getNetworkRTName(bin_path.c_str()));

    int ret = testInference(input_bins, output_bins, net, netRT);
    net->releaseLayers();
    delete net;
    delete netRT;
    return ret;
}

I will update this to include layer 44 and report back my results.

Thank you!

marvision-ai commented 4 years ago

@mive93 I have fixed the issue.

    std::vector<std::string> output_bins = {
        bin_path + "/debug/layer30_out.bin",
        bin_path + "/debug/layer37_out.bin",
        bin_path + "/debug/layer44_out.bin"
    };

Compiles and runs nicely. Thanks for the heads up.

grandprixgp commented 3 years ago

@marvision-ai I see you are still active, so I hope you don't mind me bumping this. I'm trying to convert a yolov4-tiny-3l model to be used under tkDNN. The weights are exported, but when moving on to the TensorRT conversion I run into this issue:

....

GPU free memory: 8502.9 mb.
New NetworkRT (TensorRT v7.23)
Float16 support: 1
Int8 support: 1
DLAs: 0
create execution context
Input/outputs numbers: 3
input index = 0 -> output index = 2
Data dim: 1 3 416 416 1
Data dim: 1 255 26 26 1
RtBuffer 0   dim: Data dim: 1 3 416 416 1
RtBuffer 1   dim: Data dim: 1 255 13 13 1
RtBuffer 2   dim: Data dim: 1 255 26 26 1

====== CUDNN inference ======
Data dim: 1 3 320 320 1
new_coords0
new_coords0
new_coords0
new_coords0
new_coords0
new_coords0
new_coords0
new_coords0
new_coords0
Data dim: 1 18 40 40 1

===== TENSORRT inference ====
Data dim: 1 3 320 320 1
Cuda failure: invalid argument
C:\Users\admin\Source\Repos\tkDNN\src\NetworkRT.cpp:205

I've made some modifications to get to this point, including the one you posted. I thought this had something to do with the batch size, but that doesn't seem to be the case... any ideas?

perseusdg commented 3 years ago

Can you comment more on your MSVC, CUDA, and NVIDIA driver versions?

grandprixgp commented 3 years ago

MSVC: 19.28.29336, CUDA: 11.2, NVIDIA driver: 460.90, ARCH: 86 (Ampere, RTX 3080)

I see the mention of using 465+, so I'll update that now and get back to you.

@perseusdg Edit: No change after upgrading to 466.77.

grandprixgp commented 3 years ago

https://github.com/ceccocats/tkDNN/blob/c306b368608893e92925bf143e7cf14f19525aeb/src/NetworkRT.cpp#L205

I wonder if the network size is being miscalculated; it seems to default to 416x416, while our custom model uses 320x320. Is it possible that the project expects dimensions in specific intervals? After validating each argument, the size is my biggest suspect right now.
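
As a side note on the pasted log, the mismatch is arguably already visible there if one assumes the usual YOLO head arithmetic (3 anchors per layer, 3*(classes+5) output channels, grid = input/stride): the RtBuffer shapes (255 channels, 13x13 and 26x26) correspond to a 416-input, 80-class head, while the CUDNN output (18 channels, 40x40) corresponds to a 320-input, single-class head. That may indicate the serialized engine was built from a different cfg/weights than the network being run (if a stale .rt file is being reused, rebuilding it from the current cfg would be the obvious thing to check). A minimal sketch of that arithmetic, with the anchor count and strides as assumptions:

#include <cstdio>

// channels = anchors_per_layer * (classes + 4 box coords + 1 objectness); grid = input / stride
static void head(const char* tag, int input, int classes, int stride) {
    std::printf("%s -> %d channels, %dx%d grid\n",
                tag, 3 * (classes + 5), input / stride, input / stride);
}

int main() {
    head("416 input, 80 classes, stride 32", 416, 80, 32);  // 255 channels, 13x13 (RtBuffer 1)
    head("416 input, 80 classes, stride 16", 416, 80, 16);  // 255 channels, 26x26 (RtBuffer 2)
    head("320 input,  1 class,   stride  8", 320,  1,  8);  //  18 channels, 40x40 (CUDNN output)
    return 0;
}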

grandprixgp commented 3 years ago

I managed to get the TensorRT net built and serialized. Unfortunately, I now don't seem to get any detections, similar to this issue: https://github.com/ceccocats/tkDNN/issues/228

grandprixgp commented 3 years ago

=== OUTPUT 0 CHECK RESULTS ==
CUDNN vs correct
 | [ 0 ]: nan 0.425675
 | [ 1 ]: nan 0.304462
 | [ 2 ]: nan 0.121382
 | [ 3 ]: nan 0.0980753
 | [ 4 ]: nan 0.113222
 | [ 5 ]: nan 0.0964858
 | [ 6 ]: nan 0.149539
 | [ 7 ]: nan 0.184255
 | [ 8 ]: nan 0.117385
 | Wrongs: 1800 ~0.02
TRT   vs correct
 | [ 0 ]: nan 0.425675
 | [ 1 ]: nan 0.304462
 | [ 2 ]: nan 0.121382
 | [ 3 ]: nan 0.0980753
 | [ 4 ]: nan 0.113222
 | [ 5 ]: nan 0.0964858
 | [ 6 ]: nan 0.149539
 | [ 7 ]: nan 0.184255
 | [ 8 ]: nan 0.117385
 | Wrongs: 1800 ~0.02
CUDNN vs TRT
 | [ 0 ]: nan nan
 | [ 1 ]: nan nan
 | [ 2 ]: nan nan
 | [ 3 ]: nan nan
 | [ 4 ]: nan nan
 | [ 5 ]: nan nan
 | [ 6 ]: nan nan
 | [ 7 ]: nan nan
 | [ 8 ]: nan nan
 | Wrongs: 1800 ~0.02

=== OUTPUT 1 CHECK RESULTS ==
CUDNN vs correct
 | [ 0 ]: nan 0.0156295
 | [ 1 ]: nan 0.0647845
 | [ 2 ]: nan 0.02962
 | [ 3 ]: nan 0.0509262
 | [ 4 ]: nan 0.088236
 | [ 5 ]: nan 0.0784189
 | [ 6 ]: nan 0.0819883
 | [ 7 ]: nan 0.0894271
 | [ 8 ]: nan 0.0896965
 | Wrongs: 7200 ~0.02
TRT   vs correct
 | [ 0 ]: nan 0.0156295
 | [ 1 ]: nan 0.0647845
 | [ 2 ]: nan 0.02962
 | [ 3 ]: nan 0.0509262
 | [ 4 ]: nan 0.088236
 | [ 5 ]: nan 0.0784189
 | [ 6 ]: nan 0.0819883
 | [ 7 ]: nan 0.0894271
 | [ 8 ]: nan 0.0896965
 | Wrongs: 7200 ~0.02
CUDNN vs TRT
 | [ 0 ]: nan nan
 | [ 1 ]: nan nan
 | [ 2 ]: nan nan
 | [ 3 ]: nan nan
 | [ 4 ]: nan nan
 | [ 5 ]: nan nan
 | [ 6 ]: nan nan
 | [ 7 ]: nan nan
 | [ 8 ]: nan nan
 | Wrongs: 7200 ~0.02

=== OUTPUT 2 CHECK RESULTS ==
CUDNN vs correct
 | [ 0 ]: nan 0.244062
 | [ 1 ]: nan 0.750405
 | [ 2 ]: nan 0.826826
 | [ 3 ]: nan 0.665474
 | [ 4 ]: nan 0.499885
 | [ 5 ]: nan 0.402975
 | [ 6 ]: nan 0.391236
 | [ 7 ]: nan 0.352832
 | [ 8 ]: nan 0.354092
 | Wrongs: 28800 ~0.02
TRT   vs correct
 | [ 0 ]: nan 0.244062
 | [ 1 ]: nan 0.750405
 | [ 2 ]: nan 0.826826
 | [ 3 ]: nan 0.665474
 | [ 4 ]: nan 0.499885
 | [ 5 ]: nan 0.402975
 | [ 6 ]: nan 0.391236
 | [ 7 ]: nan 0.352832
 | [ 8 ]: nan 0.354092
 | Wrongs: 28800 ~0.02
CUDNN vs TRT
 | [ 0 ]: nan nan
 | [ 1 ]: nan nan
 | [ 2 ]: nan nan
 | [ 3 ]: nan nan
 | [ 4 ]: nan nan
 | [ 5 ]: nan nan
 | [ 6 ]: nan nan
 | [ 7 ]: nan nan
 | [ 8 ]: nan nan
 | Wrongs: 28800 ~0.02

That explains it. Unfortunately, I'm not sure how to proceed from here.

Additional findings:

The above output is with TKDNN_MODE equal to FP32.

With TKDNN_MODE equal to FP16, the TensorRT engine build fails at https://github.com/ceccocats/tkDNN/blob/c306b368608893e92925bf143e7cf14f19525aeb/src/NetworkRT.cpp#L146.

With TKDNN_MODE equal to INT8, the TensorRT engine build fails at https://github.com/ceccocats/tkDNN/blob/c306b368608893e92925bf143e7cf14f19525aeb/src/Int8BatchStream.cpp#L69.

Additionally (I'm not quite sure what I did to change this behavior), I'm no longer receiving zero detections, but rather consistently 6300 detections on a validation image, all of them invalid/null. (screenshot attached)
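
On the 6300 figure: assuming the standard tiny-3l head (three yolo layers at strides 8/16/32 and 3 anchors per grid cell), a 320x320 input yields exactly 6300 candidate boxes in total, so this looks like every candidate box being kept, which is consistent with the NaN scores above. A quick sketch of that count:

#include <cstdio>

int main() {
    const int strides[] = {8, 16, 32};      // the three yolo layers of a tiny-3l model (assumed)
    int total = 0;
    for (int stride : strides) {
        int g = 320 / stride;               // grid sizes: 40, 20, 10
        total += 3 * g * g;                 // 3 anchors per grid cell
    }
    std::printf("%d\n", total);             // prints 6300
    return 0;
}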