AlexeyAB / darknet

YOLOv4 / Scaled-YOLOv4 / YOLO - Neural Networks for Object Detection (Windows and Linux version of Darknet)
http://pjreddie.com/darknet/

Hard mish #6209

Open AlexeyAB opened 4 years ago

AlexeyAB commented 4 years ago

@digantamisra98 Hi,

So hard_mish = min(2, max(0, x+2))*x/2 https://github.com/digantamisra98/H-Mish http://fooplot.com/#W3sidHlwZSI6MCwiZXEiOiJtaW4oMixtYXgoMCx4KzIpKSp4LzIiLCJjb2xvciI6IiMwMDAwMDAifSx7InR5cGUiOjEwMDB9XQ--
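
For reference, a direct (unoptimized) translation of that definition might look like the sketch below; the function name is illustrative, not code from this repo.

```cuda
// hard_mish(x) = min(2, max(0, x + 2)) * x / 2, written exactly as defined above
__device__ float hard_mish_reference(float x)
{
    return fminf(2.0f, fmaxf(0.0f, x + 2.0f)) * x * 0.5f;
}
```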

But what is the optimal code for inference and the gradient?

digantamisra98 commented 4 years ago

@AlexeyAB I'm still working on providing the most memory-friendly and fast approximation for Mish. As of now I don't have the optimal code ready, since I'm a bit occupied completing a couple of other projects, but it should be ready by the end of this month or the start of the next. I'll also then benchmark it on ImageNet just to be sure that the performance replicates.

AlexeyAB commented 4 years ago

Thanks!

PallHaraldsson commented 4 years ago

I forked the repo and made a Hard-Mish that is as fast as ReLU, with different shapes and likely more accurate:

https://github.com/PallHaraldsson/H-Mish/blob/master/README.md

AlexeyAB commented 4 years ago

@YashasSamaga Hi, did you try to optimize the forward and gradient passes of hard-mish in CUDA?

YashasSamaga commented 4 years ago

Files: https://github.com/YashasSamaga/ConvolutionBuildingBlocks/tree/master/hmish

`hmish_infer.cu`

Device: GTX 1050

| Implementation | Time (float) | Time (float4) | L2 norm |
|---|---|---|---|
| relu | 1.43ms | 1.39ms | N/A |
| hmish | 1.45ms | 1.39ms | 5.27584e-05 |

`hmish_train.cu`

Device: GTX 1050

| Implementation | Time (float) | Time (float4) | L2 norm |
|---|---|---|---|
| relu_grad | 1.50ms | N/A | N/A |
| hmish_train_fwd | 1.61ms | N/A | N/A |
| hmish_bwd | 1.60ms | 1.38ms | 0.00467884 |

The activation can be expressed compactly as (x/2) * min(2, max(0, x+2)). Inlining the min and max, we can reduce the activation to:

if (x > 0)
    return x;
if (x > -2)
    return x * x / 2 + x;
return 0;

The corresponding subgradients for each range can be expressed as:

if (x > 0)
    return 1;
if (x > -2)
    return x + 1;
return 0;
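
A minimal sketch of how these piecewise forms could be wrapped into element-wise CUDA kernels; the helper names and thread layout here are illustrative, not the tuned kernels from the files linked above:

```cuda
__device__ float hmish(float x)         // piecewise forward form from above
{
    return x > 0 ? x : (x > -2 ? x * x / 2 + x : 0);
}

__device__ float hmish_grad(float x)    // piecewise subgradient from above
{
    return x > 0 ? 1 : (x > -2 ? x + 1 : 0);
}

__global__ void hmish_forward(const float* x, float* y, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        y[i] = hmish(x[i]);
}

__global__ void hmish_backward(const float* x, const float* dy, float* dx, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        dx[i] = hmish_grad(x[i]) * dy[i];   // chain rule using the stored input
}
```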

The activation function is convex and its minimum is f(-1) = -0.5. We can safely split the function into two parts: right and left of the minimum. The partial inverses for the two sides are fairly simple.

y = x * x / 2 + x
  = 0.5 * (x^2 + 2x)
  = 0.5 * (x^2 + 2x + 1 - 1)
  = 0.5 * ((x + 1)^2 - 1)
2y + 1 = (x + 1)^2

Note that the gradient in the non-linear region is x + 1, so from the last line above the gradient is ±sqrt(2y + 1): positive to the right of the minimum and negative to the left.

Hence, if we save "one bit" of information during the forward pass indicating which side of the minimum the input is on, we can compute the gradient from the activation output and this one bit of information.

We can also compute the gradient with just the activation input if that's available. In cases where the activation output is already stored and activation input isn't, we can reduce the memory requirements by 32x by using individual bits of a 32-bit number instead of storing the activation input.
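
A rough sketch of how this one-bit scheme could look, reusing the hmish helper from the sketch above; the bit layout, kernel names, and thread mapping are illustrative assumptions, not the kernels from the linked repository:

```cuda
// Forward pass: one thread fills one 32-bit word covering 32 consecutive elements.
// Bit j records whether element (32 * word + j) was to the right of the minimum (x > -1).
__global__ void hmish_forward_pack_bits(const float* x, float* y, unsigned int* bits, int n)
{
    int word = blockIdx.x * blockDim.x + threadIdx.x;
    int start = word * 32;
    if (start >= n) return;
    unsigned int mask = 0;
    for (int j = 0; j < 32 && start + j < n; j++)
    {
        float v = x[start + j];
        y[start + j] = hmish(v);
        if (v > -1.0f)
            mask |= 1u << j;
    }
    bits[word] = mask;
}

// Backward pass: recover the gradient from the stored output and the saved bit,
// so the activation input never has to be kept around.
__global__ void hmish_backward_from_output(const float* y, const unsigned int* bits,
                                           const float* dy, float* dx, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    bool right = (bits[i / 32] >> (i % 32)) & 1;
    float g;
    if (right)
        g = y[i] > 0 ? 1.0f : sqrtf(2.0f * y[i] + 1.0f);   // linear region, or right half of the parabola
    else
        g = y[i] < 0 ? -sqrtf(2.0f * y[i] + 1.0f) : 0.0f;  // left half of the parabola, or the flat zero region
    dx[i] = g * dy[i];
}
```

The per-thread loop in the forward kernel is written for clarity rather than for coalesced memory access; a tuned implementation would organize the bit packing differently (for example with warp ballots).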

digantamisra98 commented 4 years ago

Thanks for all the work @YashasSamaga. I have updated the Mish repo to link to your repo for the faster implementation. If you'd like to push your H-Mish code to my repo, please submit a PR. Additionally, it would be great if you could help check this - https://github.com/shogun-toolbox/shogun/pull/4812. I want to get that PR done. Thanks!

PallHaraldsson commented 4 years ago

This one-bit trick is neat. Is it used for any other activation function? I'm a bit rusty and only recalled gradients, so I didn't follow on first reading (why "subgradient"?), but on a second reading it was all clear (I could still make a PR to improve the clarity a bit).

"The activation function is convex" - you mean the non-linear part only? For my third-order polynomial hard-mish, it seems you could do the same (or for similar functions with a minimum), though I didn't do the math (I was thinking it might involve a cube root, which I at least time as being as fast as a square root).

For my hard_mish2 the derivative is (x+1)(3x+1) and for my hard_mish it's ((x+4)(3x+4))/16. In principle you can always derive the gradient from y and this one bit, but I fail to see a simple function for it. Since the domain is so small, you could use one approximation for it, or two, one for either side.

YashasSamaga commented 4 years ago

> Is it used for any other activation function?

It can be used for activation functions like ReLU, Leaky ReLU, etc. It can be done for mish too but I haven't been able to find the partial inverses.
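
As a small illustration of the general idea (not code from any of the repos above), for leaky ReLU the gradient can even be recovered from the forward output alone, assuming the usual small positive negative-slope parameter:

```cuda
// y > 0 can only come from x > 0 (slope 1); y < 0 can only come from x < 0 (slope = negative_slope);
// y == 0 is the kink at x == 0, where either value is a valid subgradient.
__device__ float leaky_relu_grad_from_output(float y, float negative_slope)
{
    return y > 0 ? 1.0f : negative_slope;
}
```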

Someone at CSE also approximated mish with a third-order polynomial here.

I have plotted the functions here.

> For my hard_mish2 the derivative is (x+1)(3x+1) and for my hard_mish it's ((x+4)(3x+4))/16. In principle you can always derive the gradient from y and this one bit, but I fail to see a simple function for it.

Same. I am not able to find the partial inverses.


@digantamisra98 I didn't understand what you want me to do with the shogun PR.

digantamisra98 commented 4 years ago

Oh, I was wondering if you could maybe point out how that PR could be optimized further?

AlexeyAB commented 4 years ago

@digantamisra98 @YashasSamaga Hi, what is currently the best and fastest implementation of Mish and hard-Mish for PyTorch?

digantamisra98 commented 4 years ago

@AlexeyAB There is an implementation of Mish that has popped up in a branch of the official PyTorch repository (it seems to be by someone interning at Facebook, since the PR was merged without review and was pulled from Phabricator). Link to PR. I haven't validated its performance. As for other PyTorch implementations that are validated to be fast and memory-cheap and that mirror the learning performance of intrinsic Mish, these are the only two I know of:

  1. Mish CUDA
  2. Memory Efficient Experimental version of Mish by Ross Wightman

digantamisra98 commented 4 years ago

@AlexeyAB @WongKinYiu You might be interested in these new findings on Mish - https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu/53299/648?u=diganta

justsolo-smith commented 4 years ago

@AlexeyAB @digantamisra98 @YashasSamaga I changed mish to hard-mish in the yolov4.cfg file. The results with hard-mish are poor and the inference speed is slow. Have you done any experiments with hard-mish? Or are your results different from mine? Does hard-mish require the network to be optimized in some way?

YashasSamaga commented 4 years ago

@justsolo-smith Mish is as fast as hard-mish for inference. Hard-mish gradient computation can be made 2x faster and can consume 32x less memory, but darknet is using the unoptimized version, where the hard-mish gradient is nearly as slow as the mish gradient.

I think there is no performance difference between hard-mish and mish in the darknet repository.

justsolo-smith commented 4 years ago

@YashasSamaga Mish is as fast as hard-mish for inference? Shouldn't hard-mish be faster than mish for inference? But I'm using Darknet's default hard-mish inference, which is slower than mish... Is it because Darknet does not optimize hard-mish in CUDA? Did you test it on Darknet?

YashasSamaga commented 4 years ago

> Mish is as fast as hard-mish for inference? Shouldn't hard-mish be faster than mish for inference?

The optimized mish and hmish (both forward and backward passes) are memory-bound operations. Fetching the data from memory and storing the result accounts for the bulk of the execution time. In fact, this is so pronounced that the compute units on your GPU stay idle waiting for the memory subsystem during the mish forward pass (the optimized mish uses only 20% of my GPU's compute resources). You could add more compute work to mish and see no loss in performance.
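
This is also why the float4 timings in the tables above help: wider loads and stores make better use of the memory subsystem. A rough sketch of what a float4-vectorized hmish forward kernel could look like, reusing the hmish helper sketched earlier in this thread (it assumes n is a multiple of 4 and the buffers are 16-byte aligned; this is an illustration, not the benchmarked kernel):

```cuda
__global__ void hmish_forward_vec4(const float4* x, float4* y, int n4)   // n4 = n / 4
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n4) return;
    float4 v = x[i];            // one 16-byte load instead of four 4-byte loads
    float4 r;
    r.x = hmish(v.x);
    r.y = hmish(v.y);
    r.z = hmish(v.z);
    r.w = hmish(v.w);
    y[i] = r;                   // one 16-byte store
}
```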

> But I'm using Darknet's default hard-mish inference, which is slower than mish...

By how much? I think there should be no difference.

> Is it because Darknet does not optimize hard-mish in CUDA?

Darknet's hmish inference should be as fast as mish. Darknet's gradient for hmish should be slightly faster than the mish gradient.

> Did you test it on Darknet?

No but I have tested the kernels darknet is using in a standalone program.

mrezq651 commented 4 years ago

@digantamisra98 Hello, hope you are doing well. Please notify me when you post the optimized code; my personal email is mrezq651@gmail.com.

digantamisra98 commented 4 years ago

@mrezq651 Sure, I will. Unfortunately, due to a lack of resources, I haven't been able to do proper testing, which is why I haven't released it yet, but I'm working on it.

Goru1890 commented 3 years ago

Does Mish improve YoloV4 detection?

WongKinYiu commented 3 years ago

Yes, you can take a look at the comparison in the model zoo.