dlibml / darknet

Darknet related stuff: ports of YOLO models to dlib
Boost Software License 1.0

Feature suggestions #3

Open pfeatherstone opened 3 years ago

pfeatherstone commented 3 years ago

Just putting some suggestions out there. Maybe we could organize these into projects.

arrufat commented 3 years ago

Yes, I want to do all those things at some point :)

I've already started working on YOLO scaled models, in particular YOLOv4x-mish, which can be found here.

My idea is to make a generic template class for yolo models, where the template type is the yolo model itself, then put each one in a separate compilation unit, so that we can link to them (this greatly improves compilation time).

I've also started working on the to_label() part for yolo models, but it's not ready yet (lack of time these days). I will definitely push it unless someone wants to work on it. My dream would be to have the training part working, as well.

I am not sure how we can improve the performance. In my tests, dlib performs faster than PyTorch for densenet, resnet and vovnet architectures and uses less memory for small batch sizes (up to 4 or 8), but the tendency reverses for big batch sizes. If that is the case, dlib should be fast on single-image inference, so I am wondering whether it's the post-processing (NMS, etc.) that drags the performance down... I want to test that at some point, as well.

pfeatherstone commented 3 years ago

When I was doing my tests, both the onnx inference and the dlib inference were doing the NMS stuff. The NMS step is practically instantaneous (I haven't properly measured it, though). I think it's something else that's causing bottlenecks. But your benchmarks are very interesting, and not what I was expecting based on my tests with yolov3. Properly profiling this at some point will be useful.
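
A minimal way to check that, when someone gets around to it, is to time the forward pass and the post-processing separately. A generic sketch (net, image and nms_postprocess are placeholder names, and the forward timing only means anything with CUDA_LAUNCH_BLOCKING=1 or an explicit device sync):

#include <chrono>
#include <iostream>

// Generic stopwatch helper: returns the wall-clock time of f() in milliseconds.
template <typename F>
double time_ms(F&& f)
{
    const auto t0 = std::chrono::steady_clock::now();
    f();
    const auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count();
}

// Usage sketch (net, image and nms_postprocess are hypothetical names):
//   const double fwd_ms  = time_ms([&]{ net.forward(image); /* plus a device sync */ });
//   const double post_ms = time_ms([&]{ nms_postprocess(net.get_output()); });
//   std::cout << "forward: " << fwd_ms << " ms, post-processing: " << post_ms << " ms\n";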

arrufat commented 3 years ago

Did you set CUDA_LAUNCH_BLOCKING to 1? I used that in all my benchmarks.

pfeatherstone commented 3 years ago

So when running all the yolo models with this repository, do you get similar performance to darknet and pytorch?

pfeatherstone commented 3 years ago

Did you set CUDA_LAUNCH_BLOCKING to 1? I used that in all my benchmarks.

No, I've never used that. I imagine that would slow PyTorch down.

arrufat commented 3 years ago

I added it because the creator of PyTorch suggested it for proper benchmarking. https://github.com/arrufat/dlib-pytorch-benchmark/pull/2

pfeatherstone commented 3 years ago

Oh ok, fair enough. At the end of the day, running yolov3 with onnxruntime, PyTorch or darknet yields roughly 65 FPS on 416x416 images. With dlib, I think I got around 45 FPS. If we can close that gap, that would be great.

arrufat commented 3 years ago

Oh ok, fair enough. At the end of the day, running yolov3 with onnxruntime, PyTorch or darknet yields roughly 65 FPS on 416x416 images. With dlib, I think I got around 45 FPS. If we can close that gap, that would be great.

I've just run yolov3 on darknet and dlib and I can confirm similar numbers, more precisely, on an NVIDIA GeForce GTX 1080 Ti:

model (size)    FPS (darknet)   FPS (dlib)   VRAM darknet (MiB)   VRAM dlib (MiB)
yolov3 (416)    70              50           865                  835
yolov4 (608)    32              22           1545                 1742

I agree, it'd be cool if we could find out and fix the bottlenecks.

pfeatherstone commented 3 years ago

That's great, thanks. Whenever I have time I will have a look. A decent profiler will go a long way. I always struggle to interpret sysprof. I'll try Orbit again at some point, though I had trouble building it last time, as I seem to remember.

pfeatherstone commented 3 years ago

It could be that tensor.host() is called in a few places, which would introduce unnecessary synchronization barriers. I don't know enough about CUDA, to be honest.

arrufat commented 3 years ago

As far as I know, tensor.host() is only called once in user code (to get the actual output of the network). I need to check if it's called somewhere else inside some layer implementation...

pfeatherstone commented 3 years ago

Yep, building Google's Orbit profiler failed again. I'll have to do this at some point in my free time. Thanks @arrufat for investigating.

davisking commented 3 years ago

I'm not sure what causes this, but there shouldn't be any unnecessary tensor.host() calls. When the network is running it should all stay on the GPU. My guess is that darknet is making use of the fused conv+relu methods in cuDNN. dlib doesn't do that yet, it's still running those as 2 calls to cuDNN rather than one, which is a modest but non-trivial difference in speed if darknet is doing it like that.

pfeatherstone commented 3 years ago

@davisking Does dlib do fused convolution and batch normalisation?

pfeatherstone commented 3 years ago

Darknet definitely does that. But then again, PyTorch doesn't, and it achieves similar FPS to darknet, if not faster. I've seen onnxruntime achieve even higher FPS, but it does all sorts of crazy stuff with graph optimization.

pfeatherstone commented 3 years ago

If dlib doesn't implement fused conv-batchnorm, maybe that could be implemented as an inference-time layer visitor, which updates the convolutional filters and biases and turns the affine layers into no-ops.

pfeatherstone commented 3 years ago

Here is Alexey's code for fused conv-batchnorm:

void fuse_conv_batchnorm(network net)
{
    int j;
    for (j = 0; j < net.n; ++j) {
        layer *l = &net.layers[j];

        if (l->type == CONVOLUTIONAL) {
            //printf(" Merges Convolutional-%d and batch_norm \n", j);

            if (l->share_layer != NULL) {
                l->batch_normalize = 0;
            }

            if (l->batch_normalize) {
                int f;
                for (f = 0; f < l->n; ++f)
                {
                    l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f] + .00001));

                    double precomputed = l->scales[f] / (sqrt((double)l->rolling_variance[f] + .00001));

                    const size_t filter_size = l->size*l->size*l->c / l->groups;
                    int i;
                    for (i = 0; i < filter_size; ++i) {
                        int w_index = f*filter_size + i;

                        l->weights[w_index] *= precomputed;
                    }
                }

                free_convolutional_batchnorm(l);
                l->batch_normalize = 0;
#ifdef GPU
                if (gpu_index >= 0) {
                    push_convolutional_layer(*l);
                }
#endif
            }
        }
        else  if (l->type == SHORTCUT && l->weights && l->weights_normalization)
        {
            if (l->nweights > 0) {
                //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
                int i;
                for (i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]);
                printf(" l->nweights = %d, j = %d \n", l->nweights, j);
            }

            // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
            const int layer_step = l->nweights / (l->n + 1);    // 1 or l.c or (l.c * l.h * l.w)

            int chan, i;
            for (chan = 0; chan < layer_step; ++chan)
            {
                float sum = 1, max_val = -FLT_MAX;

                if (l->weights_normalization == SOFTMAX_NORMALIZATION) {
                    for (i = 0; i < (l->n + 1); ++i) {
                        int w_index = chan + i * layer_step;
                        float w = l->weights[w_index];
                        if (max_val < w) max_val = w;
                    }
                }

                const float eps = 0.0001;
                sum = eps;

                for (i = 0; i < (l->n + 1); ++i) {
                    int w_index = chan + i * layer_step;
                    float w = l->weights[w_index];
                    if (l->weights_normalization == RELU_NORMALIZATION) sum += lrelu(w);
                    else if (l->weights_normalization == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
                }

                for (i = 0; i < (l->n + 1); ++i) {
                    int w_index = chan + i * layer_step;
                    float w = l->weights[w_index];
                    if (l->weights_normalization == RELU_NORMALIZATION) w = lrelu(w) / sum;
                    else if (l->weights_normalization == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
                    l->weights[w_index] = w;
                }
            }

            l->weights_normalization = NO_NORMALIZATION;

#ifdef GPU
            if (gpu_index >= 0) {
                push_shortcut_layer(*l);
            }
#endif
        }
        else {
            //printf(" Fusion skip layer type: %d \n", l->type);
        }
    }
}

So a dlib visitor and a bit of tensor manipulation. Shouldn't be too bad.
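
For reference, the per-filter arithmetic the visitor would have to do is small. Here is a minimal, framework-agnostic sketch in plain C++ (the flat array layout and names are illustrative, not dlib's or darknet's actual data structures):

#include <cmath>
#include <cstddef>
#include <vector>

// Fold batch norm (gamma, beta, running mean/variance) into the preceding
// convolution, so the affine/batch-norm layer can become a no-op afterwards.
void fold_batchnorm_into_conv(std::vector<float>& weights,          // num_filters * filter_size
                              std::vector<float>& biases,           // num_filters
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              const std::vector<float>& mean,
                              const std::vector<float>& variance,
                              std::size_t filter_size,
                              float eps = 1e-5f)
{
    const std::size_t num_filters = biases.size();
    for (std::size_t f = 0; f < num_filters; ++f)
    {
        const float scale = gamma[f] / std::sqrt(variance[f] + eps);
        // w' = w * gamma / sqrt(var + eps)
        for (std::size_t i = 0; i < filter_size; ++i)
            weights[f * filter_size + i] *= scale;
        // b' = beta + (b - mean) * gamma / sqrt(var + eps)
        biases[f] = beta[f] + (biases[f] - mean[f]) * scale;
    }
}

In dlib terms, the visitor would read gamma/beta and the running statistics out of the affine_/bn_ parameters, rewrite the con_ layer's filters and biases as above, and then neutralize the affine_ layer.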

davisking commented 3 years ago

@davisking Does dlib do fused convolution and batch normalisation?

No. Need to have new layers for that.

pfeatherstone commented 3 years ago

won't a layer visitor do the job?

davisking commented 3 years ago

won't a layer visitor do the job?

Yeah or that with appropriate updates to the code.

pfeatherstone commented 3 years ago

@arrufat can you run your benchmark again, but with fuse_conv_batchnorm disabled in darknet? I've done a quick grep, and I think all you need to do is uncomment the following lines: line 2251 in parser.c, line 162 in demo.c and line 1617 in detector.c. If you're using darknet demo ..., then you won't need to do the last one. I would do it myself, but the benchmark is only meaningful if it's done on the same machine in the same "conditions". If the FPS is still around 70, then we know it's not fuse_conv_batchnorm that's causing the performance boost.

arrufat commented 3 years ago

@pfeatherstone after doing what you suggested, I go from 70 fps to 60 fps, so there's some room for improvement there :)

pfeatherstone commented 3 years ago

That's promising, but it might still suggest something else is causing the bottleneck. I'll try adding the visitor this weekend; it shouldn't take too long. It also requires adding "bypass" functionality to the affine_ layer.

arrufat commented 3 years ago

Would it be possible to make a new layer, similar to affine_, that can be constructed from a bn_ but behaves like a tag (i.e. it just forwards the input to its output without any runtime cost)?

Then we could assign a network defined with bn_ layers to a network defined with these "bypass" layers, in the same way it's done for affine_.
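
For what it's worth, here is a rough sketch of what such a bypass layer could look like, following the computational layer interface documented in dlib's layers_abstract.h. The names bypass_/bypass are made up, serialization and printing are omitted for brevity, and a truly zero-cost version would need dlib's in-place layer support rather than the copy done here:

#include <dlib/dnn.h>

using namespace dlib;

class bypass_
{
public:
    bypass_() = default;

    // Allows assigning a network that has bn_ in this position, like affine_ does.
    template <layer_mode mode>
    bypass_(const bn_<mode>&) {}

    template <typename SUBNET> void setup(const SUBNET&) {}

    template <typename SUBNET>
    void forward(const SUBNET& sub, resizable_tensor& output)
    {
        // Pure pass-through: copy the subnet's output tensor.
        output.copy_size(sub.get_output());
        memcpy(output, sub.get_output());
    }

    template <typename SUBNET>
    void backward(const tensor& gradient_input, SUBNET& sub, tensor&)
    {
        // Identity gradient: accumulate into the subnet's gradient.
        tt::add(1, sub.get_gradient_input(), 1, gradient_input);
    }

    const tensor& get_layer_params() const { return params; }
    tensor& get_layer_params() { return params; }

private:
    resizable_tensor params;  // intentionally empty
};

template <typename SUBNET>
using bypass = add_layer<bypass_, SUBNET>;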

pfeatherstone commented 3 years ago

Can do, I have no strong opinions. I imagine there would be no runtime cost with a flag either, since branch prediction would always guess correctly (presumably there would be an if-statement around the flag: if true, simply forward the input to the output). Honestly it makes no difference to me; whichever is the most expressive. Your way requires a new type, which means the whole network is a new type, which means the compiler has to compile yet another gigantic type, which means I have to wait another 15 minutes for clang to build yolov3. But at this stage, ±15 minutes for building networks in dlib isn't a biggie.

arrufat commented 3 years ago

Yes, I agree, compile times are getting a bit out of hand for big YOLO models (such as the recently published improvements to YOLOv4). Maybe an extra branch in each bn_ layer wouldn't affect the performance...

Regarding the compile times, that's why I build each model as a separate library and then link to it, so I don't have to rebuild it every time I change the code somewhere else.

https://github.com/dlib-users/darknet/blob/b78eddc08a7f5520103b2b296067a3516f5f7faa/CMakeLists.txt#L74-L77

Here are the sizes of the compiled models, yolov3 is really tiny compared to the latest yolov4x_mish...

-rw-r--r--  1 adria 1.9M Jan 18 23:43 libyolov3.a
-rw-r--r--  1 adria 4.8M Jan 18 23:43 libyolov4.a
-rw-r--r--  1 adria 5.2M Jan 18 23:43 libyolov4_sam_mish.a
-rw-r--r--  1 adria  15M Jan 18 23:45 libyolov4x_mish.a

pfeatherstone commented 3 years ago

If I had a couple of months of free time, I would roll up my sleeves and propose a new functional API for DNNs in dlib, using dynamic polymorphism instead of static polymorphism for neural networks. I think that would solve a lot of frustrations, including compile times. I can see the benefits of using templates, since it means you delegate optimisations to the compiler, but with large models, as you said, it gets out of hand. Having a functional API similar to PyTorch, for example, would make DNNs more accessible in dlib, I think. But this would require a huge amount of time to get right.

pfeatherstone commented 3 years ago

But that would require a lot of work on the tensor type too, I think. So this wouldn't be an easy thing to do.

pfeatherstone commented 3 years ago

Yes, I agree, compile times are getting a bit out of hand for big YOLO models (such as the recently published improvements to YOLOv4). Maybe an extra branch in each bn_ layer wouldn't affect the performance...

Regarding the compile times, that's why I build each model as a separate library and then link to it, so I don't have to rebuild it every time I change the code somewhere else.

https://github.com/dlib-users/darknet/blob/b78eddc08a7f5520103b2b296067a3516f5f7faa/CMakeLists.txt#L74-L77

Here are the sizes of the compiled models, yolov3 is really tiny compared to the latest yolov4x_mish...

-rw-r--r--  1 adria 1.9M Jan 18 23:43 libyolov3.a
-rw-r--r--  1 adria 4.8M Jan 18 23:43 libyolov4.a
-rw-r--r--  1 adria 5.2M Jan 18 23:43 libyolov4_sam_mish.a
-rw-r--r--  1 adria  15M Jan 18 23:45 libyolov4x_mish.a

It's still impressive that a single model compiles to nearly 2 MB of binary. Maybe the bottlenecks are caused by code bloat? I don't know; I've never properly looked at the effects of binary size on performance.

arrufat commented 3 years ago

Honestly, I really like the declarative way of defining networks in dlib. Even if it requires some work to add new layers, it's worth it because:

Other than the compile times, I think dlib's approach to neural nets is the best (but I might be biased :P)

EDIT: also, if at some point dlib is ported to C++20, we could use concepts to get better error messages when we make a mistake in the network definition; that would be awesome.

pfeatherstone commented 3 years ago

We could achieve better error messages using SFINAE; I've found std::void_t is a great utility for that. What I love about dlib is that it only requires C++11, which makes it highly portable. If you bump that up to C++20, that restricts a lot of the places where it can run. And I agree with most of what you said above: compile-time checks are useful. But getting the tensor shapes right in PyTorch isn't a huge effort, and you only have to do it once. It also forces you to think about what's going on, which I like a lot. But developing a large model in dlib is a massive pain at the moment due to compilation times, and I don't think C++ is getting any closer to reducing compilation times for templates. On the other hand, with dlib your model gets compiled to a single type, which, if you turn optimisations on when compiling, has all sorts of nice properties like model obfuscation. So I do love dlib's DNN stuff, but I don't think it lends itself nicely to research; I use it more for production and inference only. Which is a shame, as it could be a first-class research tool.
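
As an aside, the std::void_t detection idiom mentioned above looks roughly like this (generic C++, not dlib code; std::void_t is technically C++17, but a C++11 equivalent is a one-line alias template):

#include <type_traits>
#include <utility>

// Detect at compile time whether T has a get_output() member, so a network
// definition mistake turns into one readable static_assert instead of pages
// of template errors.
template <typename T, typename = void>
struct has_get_output : std::false_type {};

template <typename T>
struct has_get_output<T, std::void_t<decltype(std::declval<T&>().get_output())>>
    : std::true_type {};

template <typename SUBNET>
void require_subnet()
{
    static_assert(has_get_output<SUBNET>::value,
                  "SUBNET must provide get_output(); did you forget a layer or tag?");
}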

arrufat commented 3 years ago

@pfeatherstone you might have something to say about this https://github.com/davisking/dlib/pull/2294

pfeatherstone commented 3 years ago

This looks interesting. Alexey from darknet claims a 20% performance improvement for detection models, though users are reporting that it's buggy (in darknet) and only provides a 1% improvement. It could be a problem with the darknet implementation, but it's worth knowing about, I think.

https://developer.nvidia.com/blog/cuda-graphs/

arrufat commented 3 years ago

@pfeatherstone @davisking

So, I've added a simple implementation of the YOLO loss in the loss_yolo branch. It still doesn't work, but I think it's really close. I've been tinkering with it for several days and couldn't make it work.

The implementation draws inspiration from:

If you have some time and can point out where I made a mistake, I would be really grateful. Once it's working, I will clean it up and prepare it for a PR to the official dlib repo. I think the current interface is quite nice and the implementation easy to follow.

Thanks in advance.

arrufat commented 3 years ago

In case you want to try it, these are the voc2012 and coco2017 datasets converted to dlib's imglab format.

pfeatherstone commented 3 years ago

I can have a peek at some point. I would suggest that the bounding-box loss should be one of the GIoU, DIoU or CIoU losses. I've used these in PyTorch and they converge much faster than an L1 loss. So if we're going to give this to dlib users, we may as well make it attractive and behave like a "state-of-the-art" bounding-box regressor.
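
For reference, the GIoU term itself is only a few lines. A minimal sketch in plain C++, with boxes as corner coordinates (DIoU/CIoU add centre-distance and aspect-ratio penalties on top, and the training loss would simply be 1 - giou):

#include <algorithm>

struct box { double x1, y1, x2, y2; };  // top-left and bottom-right corners

inline double area(const box& b)
{
    return std::max(0.0, b.x2 - b.x1) * std::max(0.0, b.y2 - b.y1);
}

inline double giou(const box& a, const box& b)
{
    // intersection and union
    const box i{std::max(a.x1, b.x1), std::max(a.y1, b.y1),
                std::min(a.x2, b.x2), std::min(a.y2, b.y2)};
    const double inter = area(i);
    const double uni   = area(a) + area(b) - inter;
    const double iou   = inter / (uni + 1e-9);

    // smallest enclosing box
    const box c{std::min(a.x1, b.x1), std::min(a.y1, b.y1),
                std::max(a.x2, b.x2), std::max(a.y2, b.y2)};
    const double c_area = area(c);

    // GIoU = IoU - (enclosing area not covered by the union) / enclosing area
    return iou - (c_area - uni) / (c_area + 1e-9);
}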

arrufat commented 3 years ago

Yes, I want to add that at some point, but the smoothed L1 loss was easier to implement... I want to have something that works first and then tune it.

Also, the BBR loss is an implementation detail; I think it won't change the interface, so it can be modified later.

pfeatherstone commented 3 years ago

There are going to be some gotchas, though. When I've trained YOLO models (not just on image data, by the way...), I've had to use a burn-in scheduler which very slowly increases the learning rate for the first X batches (e.g. 500 batches), apply tricks like increasing the learning rate of the biases, and use tuned hyper-parameters for the different loss functions, weight decay and so on. I don't know if any of these little adjustments fit in the dlib framework, or, if they do, how much they bloat the training code. I still wonder if the best thing to do for running YOLO models in dlib is to train in darknet/pytorch, convert the model, then infer in dlib if you can't afford to link to clunky frameworks. We have the weights visitor now, so it suffices to train in darknet, which handles all these annoying details of training YOLO models. But it's up to you if you absolutely want to train in dlib. It sounds like a huge amount of work for little gain, and then you still have the problem that dlib is 25% slower than other frameworks on GPU.

pfeatherstone commented 3 years ago

And yeah, dlib doesn't have autograd, so you're going to have to hand-code the backward pass of CIoU, which I imagine has the ugliest derivative ever computed.

pfeatherstone commented 3 years ago

However, if training YOLO in dlib works, people might abandon darknet, which is a pile of crap, use dlib, then submit PRs and make the DNN stuff in dlib better. So that might be a good enough reason to support YOLO in dlib.

pfeatherstone commented 3 years ago

But I'm willing to bet that somewhere down the road you will need a dlib::yolo_trainer object which handles all the subtle training "recipes" specific to YOLO.

pfeatherstone commented 3 years ago

I'm not trying to put a downer on anything, but I think you've set yourself up for an enormous amount of work to get it working at a level comparable to the ultralytics repositories, for example.

arrufat commented 3 years ago

Maybe I am a bit too optimistic, but I think it can definitely be done. Maybe we won't get state-of-the-art performance at first, but getting it to work acceptably well should be doable. Regarding the scheduler, it's already supported in dlib; in my latest experiments I've tried it, but still nothing...

// Cosine scheduler with burn-in:
// - learning_rate is the highest learning rate value, e.g. 0.01
// - burnin: number of steps to linearly increase the learning rate
// - steps: maximum number of steps of the training session
const matrix<double> learning_rate_schedule = learning_rate * join_rows(                
    linspace(0, 1, burnin),
    ((1 + cos(pi / (steps - burnin) * linspace(0, steps - burnin, steps - burnin))) / 2)
) + std::numeric_limits<double>::epsilon();  // this prevents learning rates from being 0

// Tell the trainer to use it, instead of the default one
trainer.set_learning_rate_schedule(learning_rate_schedule);

We could also fiddle with the bias learning rates and the weight decays using network visitors, if that was absolutely necessary.

pfeatherstone commented 3 years ago

Oh wow, I did not know trainer.set_learning_rate_schedule existed. Looks very nice.

pfeatherstone commented 3 years ago

I like the fact that you pass the entire learning rate schedule rather than a lambda. It seems silly that most frameworks accept a lambda instead: most of the time your schedule isn't stateful, and in fact you usually want to plot it anyway, so you end up running the lambda over something like linspace(0, 1, nepochs) anyway... So you may as well compute it up front.
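
For example, any per-step rule can be baked into a schedule up front. A small sketch, where my_rule is an arbitrary placeholder and trainer is the dnn_trainer from the snippet above:

#include <dlib/matrix.h>

const long total_steps = 100000;
const auto my_rule = [](double t) { return 0.01 * (1.0 - t); };  // e.g. linear decay over [0, 1)

// Evaluate the rule once into a matrix; this is also exactly what you would plot.
dlib::matrix<double> learning_rates(1, total_steps);
for (long i = 0; i < total_steps; ++i)
    learning_rates(0, i) = my_rule(double(i) / total_steps);

trainer.set_learning_rate_schedule(learning_rates);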

arrufat commented 3 years ago

What's nice is that the output log of the trainer changes to look like this:

step#: 0     learning rate: 2.22045e-16  average loss: 0            percent complete: 0.00%
step#: 118   learning rate: 1.16012e-05  average loss: 4.11549      percent complete: 0.12%
step#: 237   learning rate: 2.35024e-05  average loss: 3.65662      percent complete: 0.24%
step#: 358   learning rate: 3.56036e-05  average loss: 2.96204      percent complete: 0.36%
step#: 479   learning rate: 4.77048e-05  average loss: 2.32915      percent complete: 0.48%
step#: 600   learning rate: 5.98060e-05  average loss: 1.81664      percent complete: 0.60% 
step#: 722   learning rate: 7.20072e-05  average loss: 1.36227      percent complete: 0.72%
step#: 843   learning rate: 8.41084e-05  average loss: 1.12913      percent complete: 0.84%
step#: 965   learning rate: 9.63096e-05  average loss: 0.900203     percent complete: 0.96%
step#: 1088  learning rate: 0.000108611  average loss: 0.783291     percent complete: 1.09%
step#: 1211  learning rate: 0.000120912  average loss: 0.630792     percent complete: 1.21%
step#: 1333  learning rate: 0.000133113  average loss: 0.583284     percent complete: 1.33%
step#: 1455  learning rate: 0.000145315  average loss: 0.515598     percent complete: 1.46%
step#: 1578  learning rate: 0.000157616  average loss: 0.457426     percent complete: 1.58%
step#: 1700  learning rate: 0.000169817  average loss: 0.442811     percent complete: 1.70%
step#: 1822  learning rate: 0.000182018  average loss: 0.405102     percent complete: 1.82%
step#: 1943  learning rate: 0.000194119  average loss: 0.377133     percent complete: 1.94%
step#: 2065  learning rate: 0.000206321  average loss: 0.373406     percent complete: 2.06%
step#: 2187  learning rate: 0.000218522  average loss: 0.371894     percent complete: 2.19%
step#: 2308  learning rate: 0.000230623  average loss: 0.301314     percent complete: 2.31%
step#: 2430  learning rate: 0.000242824  average loss: 0.345876     percent complete: 2.43%
step#: 2553  learning rate: 0.000255126  average loss: 0.308972     percent complete: 2.55%
step#: 2675  learning rate: 0.000267327  average loss: 0.375911     percent complete: 2.67%
Saved state to yolov3_sync

It would be even nicer if it added two extra columns at the right with the elapsed and remaining times :P

arrufat commented 3 years ago

I've found a source of slowness, not much, but still. Currently, the darknet weights visitor is doing this: https://github.com/dlib-users/darknet/blob/2502881eba1217b9c6a316c7fdf01e2b33573726/src/weights_visitor.h#L97-L101 to get the number of input channels for the convolution. However, that doesn't work when the input of the convolutional layer is an input layer. As a hack, I tagged the input layer like this: https://github.com/dlib-users/darknet/blob/2502881eba1217b9c6a316c7fdf01e2b33573726/src/darknet.h#L151-L162

As I was playing with the YOLO loss, I realized I didn't need the tag1 anymore, since I was not loading any weights. After that, the VRAM fluctuation during training and inference disappeared (and I gained some fps, too).

Actually, we don't need to call l.subnet().get_output().k(), we can guess the number of input filters by doing:

tensor& params = l.layer_details().get_layer_params();
// params holds the nf filters (each nr x nc x k_in values) plus nf biases,
// so the number of input channels falls out of the parameter count:
const long nf_in = (params.size() - nf) / nf / nr / nc;

That memory fluctuation has been bothering me for some time, and it's finally gone. However, I thought tag layers incurred no cost at runtime... Maybe @davisking has some explanation?

Reference: https://github.com/dlib-users/darknet/commit/78c05d2b30776d5dcaf1624bd401381b0de1838e

pfeatherstone commented 3 years ago

I'm convinced all the template bloat in the dnn code is adding latency.

pfeatherstone commented 3 years ago

How is the visitor causing latency? It's only ever called once.

arrufat commented 3 years ago

@pfeatherstone the visitor is not causing latency; the problem is that, for the visitor to work, I needed to add a tag layer to the input layer, otherwise the call to l.subnet().get_output().k() would not compile, since input layers don't have a get_output() method. That's why I wrapped it in a tag layer.

That commit just removes the need to tag the input layer for the visitor to work. Then you can remove the tag1 from the input layers in the network definition, the memory won't fluctuate any more, and the network will run a bit faster.