AlexeyAB / darknet

YOLOv4 / Scaled-YOLOv4 / YOLO - Neural Networks for Object Detection (Windows and Linux version of Darknet)
http://pjreddie.com/darknet/

Memory friendly/ faster Mish #5922

Open digantamisra98 opened 4 years ago

digantamisra98 commented 4 years ago

Since I have observed some practitioners reporting issues with slow training or high memory consumption when using Mish, here is a curated list of Mish implementations that are significantly faster and cheaper in terms of memory (in some cases close to in-place ReLU); a rough CUDA sketch of the recompute-in-backward idea the memory-friendly variants rely on is given after the list:

  1. [FastAI Autograd Mish Implementation](https://github.com/fastai/fastai2/blob/14e148049acd2ec26a3cc16ec79048c0aa94def1/fastai2/layers.py#L546) (related discussion on the FastAI forums; inherited from Ross Wightman's implementation listed below)

  2. Mish-CUDA (faster than the native softplus-based implementation; requires a CUDA SDK installation and supports GPU profiling only). Notable results from this implementation, profiled over 100 runs after 10 warmup runs on a GeForce RTX 2070:

    
    Testing on torch.float16:
    relu_fwd:      223.7µs ± 1.026µs (221.6µs - 229.2µs)
    relu_bwd:      312.1µs ± 2.308µs (307.8µs - 317.4µs)
    softplus_fwd:  342.2µs ± 38.08µs (282.4µs - 370.6µs)
    softplus_bwd:  488.5µs ± 53.75µs (406.0µs - 528.4µs)
    mish_pt_fwd:   658.8µs ± 1.467µs (655.9µs - 661.9µs)
    mish_pt_bwd:   1.135ms ± 4.785µs (1.127ms - 1.145ms)
    mish_cuda_fwd: 267.3µs ± 1.852µs (264.5µs - 274.2µs)
    mish_cuda_bwd: 345.6µs ± 1.875µs (341.9µs - 349.8µs)

    Testing on torch.float32:
    relu_fwd:      234.2µs ± 621.8ns (233.2µs - 235.7µs)
    relu_bwd:      419.3µs ± 1.238µs (417.8µs - 426.0µs)
    softplus_fwd:  255.1µs ± 753.6ns (252.4µs - 256.5µs)
    softplus_bwd:  420.2µs ± 631.4ns (418.2µs - 421.9µs)
    mish_pt_fwd:   797.4µs ± 1.094µs (795.4µs - 802.8µs)
    mish_pt_bwd:   1.689ms ± 1.222µs (1.686ms - 1.696ms)
    mish_cuda_fwd: 282.9µs ± 876.1ns (281.1µs - 287.8µs)
    mish_cuda_bwd: 496.3µs ± 1.781µs (493.6µs - 503.0µs)

    Testing on torch.float64:
    relu_fwd:      450.4µs ± 879.7ns (448.8µs - 456.4µs)
    relu_bwd:      834.2µs ± 925.8ns (832.3µs - 838.8µs)
    softplus_fwd:  6.370ms ± 2.348µs (6.362ms - 6.375ms)
    softplus_bwd:  2.359ms ± 1.276µs (2.356ms - 2.365ms)
    mish_pt_fwd:   10.11ms ± 2.806µs (10.10ms - 10.12ms)
    mish_pt_bwd:   4.897ms ± 1.312µs (4.893ms - 4.901ms)
    mish_cuda_fwd: 8.989ms ± 3.646µs (8.980ms - 9.007ms)
    mish_cuda_bwd: 10.92ms ± 3.966µs (10.91ms - 10.93ms)



3. [Ross Wightman's implementation](https://github.com/rwightman/gen-efficientnet-pytorch/blob/8795d3298d51ea5d993ab85a222dacffa8211f56/geffnet/activations/activations_autofn.py#L41) (used in FastAI)

4. [H-Mish](https://github.com/digantamisra98/H-Mish) (WIP)
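
For context, here is a rough CUDA sketch (my own, not code taken from any of the implementations above) of the memory-saving idea the autograd/CUDA variants rely on: keep only the input tensor from the forward pass and recompute the transcendentals in the backward kernel, instead of caching the softplus/tanh intermediates.

// forward: y = x * tanh(softplus(x)); only x needs to be kept for the backward pass
__global__ void mish_fwd(float* y, const float* x, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        float sp = log1pf(expf(x[i]));              // softplus(x)
        y[i] = x[i] * tanhf(sp);
    }
}

// backward: recompute tanh(softplus(x)) and sigmoid(x) from the saved input
// rather than storing them during the forward pass
__global__ void mish_bwd(float* dx, const float* dy, const float* x, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        float sp   = log1pf(expf(x[i]));
        float tsp  = tanhf(sp);                     // tanh(softplus(x))
        float sig  = 1.0f / (1.0f + expf(-x[i]));   // d/dx softplus(x)
        float grad = tsp + x[i] * (1.0f - tsp * tsp) * sig;
        dx[i] = dy[i] * grad;
    }
}

A production kernel would typically also use a threshold-based softplus to avoid overflow for large inputs; this sketch only illustrates the memory trade-off.
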
digantamisra98 commented 4 years ago

@glenn-jocher @WongKinYiu @thomasbrandon @rwightman - tagging for any potential clarifications/ improvements.

digantamisra98 commented 4 years ago

For TensorFlow, Mish is already available in the TensorFlow Addons package.

rwightman commented 4 years ago

@digantamisra98 FYI, the optimized Mish in FastAI was cut & pasted from my impl, with no attribution comment left.

I have an experimental version of your H-Mish (at least based on the current note; let me know if it changes) along with some refinements to my previous activations. It's pretty fast in this form, with very little memory overhead as well. In testing so far it converges faster in early training than H-Swish, but I'm not seeing much differentiation by the end of a long training run with lots of augmentation, using RMSProp and EMA averaging.

https://github.com/rwightman/pytorch-image-models/blob/densenet_update_and_more/timm/models/layers/activations_me.py

digantamisra98 commented 4 years ago

@rwightman Thanks for the clarification! I am still working on H-Mish, so nothing is definitive as of now, but it does seem to work fairly well in my preliminary testing. Could you instead provide some benchmark numbers for your H-Mish implementation? Thanks!

rwightman commented 4 years ago

@digantamisra98 so far I've been trying it with MobileNetV3, which is normally a mix of relu and hard-swish. I changed the hard-swish to hard-mish. My best normal training run of MV3 was 75.768 top-1. Using the same hparams, h-mish ended up around 75.56. Using a different set of hparams, a higher initial LR and some regularization changes, I hit 75.752 with h-mish. Running again with h-swish and the new hparams (although I had to drop the LR back a bit; h-mish seems to like a higher LR), and we'll see where it lands...

rwightman commented 4 years ago

I feel I could probably push the LR even higher with h-mish, but these runs take some time (full ImageNet, 500+ epochs), so yeah, only so many experiments can be run :)

arrufat commented 4 years ago

For the dlib implementation, here are the profiling results on a GeForce GTX 1080 Ti for a 64x3x224x224 tensor, over 100 runs after 10 warmup runs:

Benchmark code: https://github.com/arrufat/dlib-activations-benchmark
Actual Mish implementation: https://github.com/davisking/dlib/blob/master/dlib/cuda/cuda_dlib.cu#L1421-L1482

rwightman commented 4 years ago

@digantamisra98 Running some autograd vs custom-grad tests, something is a bit off with my hard_mish, so I need to check again and possibly re-run; will update you when that's sorted. EDIT: stupid mistake, a cut & paste error rather than a derivative issue; I was using mish, heh... will rerun the last session.

YashasSamaga commented 4 years ago

Results for inputs sampled uniformly from [-50, 50].

| Implementation | Time (float) | Time (float4) | L2 norm (float) |
| --- | --- | --- | --- |
| relu | 1.47ms | 1.39ms | Not applicable |
| mish_tb | 2.06ms | 1.68ms | 0.000303575 |
| mish_rw | 2.10ms | 1.78ms | 0.000699724 |
| mish_njuffa1 | 1.89ms | 1.55ms | 0.000766649 |
| mish_njuffa2 | 3.14ms | 2.73ms | 2.48238e-05 |
| mish_njuffa3 | 2.15ms | 1.80ms | 0.000132822 |
| mish_aj1 | 2.54ms | 2.31ms | 0.000268734 |
| mish_aj2 | 1.94ms | 1.72ms | NaN |
| mish_aj2_fastdiv | 1.50ms | 1.39ms | NaN |
| mish_dlib | 2.34ms | 2.10ms | 0.000699327 |
| mish_ocv (old) | 1.50ms | 1.39ms | 0.00012458 |
| mish_ocv (new) | 1.49ms | 1.39ms | 2.4583e-05 |

mish_ocv (old) is what Darknet and OpenCV currently use. mish_ocv (new) is the version used in the code below.

Code: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

IMPORTANT: The timings were measured using Nsight Compute and I manually copied them from the profiler to the table. There could be mistakes while copying as well as in the code. Please verify.

The L2 norm is computed against a reference calculated in double precision on the CPU as x * std::tanh(std::log1p(std::exp(x))). This is not really the right way to compare, but it is OK for a quick comparison, I guess.
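
A rough sketch of that comparison (my own reconstruction for illustration, not the benchmark gist's code; function names are mine):

#include <cmath>
#include <cstddef>
#include <vector>

// double-precision CPU reference: mish(x) = x * tanh(softplus(x))
double mish_reference(double x)
{
    return x * std::tanh(std::log1p(std::exp(x)));
}

// L2 norm of the error of a candidate implementation's outputs
double l2_error(const std::vector<float>& input, const std::vector<float>& output)
{
    double sum = 0.0;
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        double err = static_cast<double>(output[i]) - mish_reference(input[i]);
        sum += err * err;
    }
    return std::sqrt(sum);
}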

@armadillojim Your second version is giving NaNs for large inputs.

AlexeyAB commented 4 years ago

@digantamisra98 @rwightman @YashasSamaga Hi, are there any Mish or Mish-like backward implementations where the gradient can be calculated using the result of the activation function instead of its input value, to reduce memory consumption? (Even with a strong slowdown.)

Like we can do for a sigmoid: gradient_sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x), which needs only the sigmoid output.
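
For example, a sigmoid backward kernel only needs the saved output y = sigmoid(x) (a trivial sketch, not darknet's actual code):

// gradient of sigmoid computed purely from the saved activation output y,
// so the input x never has to be kept around
__global__ void sigmoid_bwd_from_output(float* dx, const float* dy, const float* y, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        dx[i] = dy[i] * y[i] * (1.0f - y[i]);
}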

digantamisra98 commented 4 years ago

@AlexeyAB you can do it this way, but again, this might not be memory friendly; I can refactor it further, though. [image: Capture — derivation of the Mish gradient]

The sech^2 and softplus terms can be further optimized to make it more memory friendly, but you are still input-dependent in the denominator of the 2nd term.
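
For reference, the standard chain-rule expansion being discussed is:

$$\frac{d}{dx}\,\mathrm{Mish}(x) \;=\; \tanh(\mathrm{softplus}(x)) \;+\; x\,\mathrm{sech}^2(\mathrm{softplus}(x))\,\sigma(x), \qquad \sigma(x) = \frac{e^x}{1 + e^x}$$

Both sech^2(softplus(x)) and σ(x) in the second term depend on the input x, which is the input dependence referred to above.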

YashasSamaga commented 4 years ago

The Mish gradient requires 2 LOADs + 1 STORE while the ReLU gradient requires 1 LOAD + 1 STORE. Both are memory-bound kernels, and hence Mish gradient calculation can never be as fast as the ReLU gradient, or even get close, without reducing memory accesses.

I think I can close the current gap between the limit and the gradient computation with some trickery, but the gap is presumably already very small (three memory ops give a lot of room for computation). We should instead focus on reducing bandwidth consumption.
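
To put a number on that limit (a back-of-the-envelope estimate assuming FP32 elements and a purely bandwidth-bound kernel, using the access counts above):

$$\frac{T_{\text{mish\_grad}}}{T_{\text{relu\_grad}}} \;\gtrsim\; \frac{(2 + 1) \times 4 \text{ bytes/element}}{(1 + 1) \times 4 \text{ bytes/element}} \;=\; 1.5$$

so even a perfect Mish gradient kernel that keeps the full FP32 input would sit around 1.5x the ReLU gradient time.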

Note that the mish gradient is one beyond ~9.0f and nearly zero for large negative values. You can store the activation input in FP16. FP16 is inaccurate for large numbers, but the inaccuracy doesn't affect the result (2049 stored as 2048 is going to give the same gradient! And yes, FP16 cannot store all integers beyond 2048!)

You can trade exponent bits for mantissa bits: a 4-bit exponent and an 11-bit mantissa. Convert the activation input to this format for storage during the forward pass and convert it back to FP32 in the gradient kernel. The gradient will be more accurate than with FP16.


I tried it with FP16. It turns out to be really inaccurate: the Mish gradient involves too many operations, and small errors easily get amplified.
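
A rough sketch of the 4-bit-exponent / 11-bit-mantissa packing suggested above (my own reconstruction, untested; the names pack_e4m11/unpack_e4m11, the exponent bias of 8, the flush-to-zero/saturate policy, and the truncating mantissa are all arbitrary choices, and NaN/Inf are not handled):

// pack an FP32 value into 16 bits: 1 sign bit, 4 exponent bits, 11 mantissa bits
__device__ unsigned short pack_e4m11(float x)
{
    unsigned int bits = __float_as_uint(x);
    unsigned int sign = bits >> 31;
    int exp = (int)((bits >> 23) & 0xFF) - 127;     // unbiased FP32 exponent
    unsigned int mant = (bits >> 12) & 0x7FF;       // truncate to the top 11 mantissa bits

    if (exp < -7)                                   // too small: flush to (signed) zero
        return (unsigned short)(sign << 15);
    if (exp > 7) { exp = 7; mant = 0x7FF; }         // too large: saturate to the max finite value

    unsigned int biased = (unsigned int)(exp + 8);  // bias 8: stored values 1..15, 0 means zero
    return (unsigned short)((sign << 15) | (biased << 11) | mant);
}

// expand the 16-bit value back to FP32 inside the gradient kernel
__device__ float unpack_e4m11(unsigned short h)
{
    unsigned int sign = (h >> 15) & 0x1u;
    unsigned int biased = (h >> 11) & 0xFu;
    unsigned int mant = h & 0x7FFu;

    if (biased == 0)                                // stored zero
        return sign ? -0.0f : 0.0f;

    unsigned int bits = (sign << 31) | ((biased - 8 + 127) << 23) | (mant << 12);
    return __uint_as_float(bits);
}

With these choices the representable magnitudes run from about 2^-7 up to just under 2^8, which, per the note above, covers the region where the Mish gradient is neither ~0 nor ~1.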

YashasSamaga commented 4 years ago

~2x faster gradient for ReLU, PReLU, Leaky, etc. at a small memory cost:

Extra memory required: input_size / 32 (if your input tensor is 64MB, you need 2MB workspace)

It essentially packs the signs of 32 inputs into a single unsigned int. The gradient kernel then has to load less data (which improves performance considerably, as it is memory-bound).

Forward pass:

__global__ void relu(float* output, const float* input, unsigned int* sign32, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // compute ReLU and remember whether the input was positive
    bool sign = 0;
    if (i < n)
    {
        auto inp = input[i];
        sign = inp > 0;
        output[i] = sign ? inp : 0;
    }

    // pack the sign bits of all 32 threads in the warp into one word;
    // lane 0 writes the packed word (assumes blockDim.x is a multiple of 32)
    unsigned predicate = __ballot_sync(0xFFFFFFFF, sign);
    if (threadIdx.x % 32 == 0)
        sign32[i / 32] = predicate;
}

Backward pass:

// N = 1 for scalar stores, N = 4 for vectorized float4 stores
template <int N>
__global__ void relu_grad_fast(float* dz, const unsigned int* sign32, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // load the packed sign word covering this thread's elements through the
    // read-only cache; __brev reverses the bits so they can be tested with
    // the right-shifted 0x80000000 masks below
    unsigned int predicate = __brev(__ldg(&sign32[i / (32 / N)]));
    if (i < n)
    {
        // bit offset of this thread's first element within the packed word
        const int laneid_byN = threadIdx.x % (32 / N) * N;

        if (N == 4)
        {
            // vectorized path: unpack four sign bits and write four 0/1 gradients at once
            float4 dy;
            dy.x = (predicate & (0x80000000 >> laneid_byN)) != 0;
            dy.y = (predicate & (0x80000000 >> (laneid_byN + 1))) != 0;
            dy.z = (predicate & (0x80000000 >> (laneid_byN + 2))) != 0;
            dy.w = (predicate & (0x80000000 >> (laneid_byN + 3))) != 0;
            reinterpret_cast<float4*>(dz)[i] = dy;
        }
        else if (N == 1)
        {
            // scalar path: one sign bit -> one 0/1 gradient
            dz[i] = (predicate & (0x80000000 >> laneid_byN)) != 0;
        }
        else
        {
            static_assert(N == 4 || N == 1, "");
        }
    }
}

sign32 is an array of unsigned int with (num_elements + ((block_size + 31) / 32) * 32 - 1) / 32 elements. This uses one bit per element instead of one FP32 entry.

The / 32 can be reduced to a single shift if i is unsigned (otherwise extra instructions are needed to handle sign extension).
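
For illustration (a trivial sketch; sign_word_index is just a made-up helper name):

// with an unsigned index the divide-by-32 compiles to a single logical
// shift right, with no sign-extension fixup
__device__ unsigned int sign_word_index(unsigned int i)
{
    return i >> 5;   // same as i / 32
}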


Code: https://gist.github.com/YashasSamaga/c694859eff9bcc596611abb85eaeb673

| # | Time (float) | Time (float4) |
| --- | --- | --- |
| relu_fwd | 1.58ms | N/A |
| relu_bwd_normal | 1.49ms | N/A |
| relu_bwd_fast | 1.20ms | 722µs |

YashasSamaga commented 4 years ago

> are there any Mish or Mish-like backward implementations where the gradient can be calculated using the result of the activation function instead of its input value

@AlexeyAB I think that is impossible. Mish is not one-to-one and hence no inverse mapping exists.

Mish gives ~0.0f as output for both x = -50 and x = 0. The gradients are 0.0f and 0.6f respectively. If you pass just the Mish output of 0.0f, the gradient function cannot tell whether it needs to return 0.0f or 0.6f.

You can salvage this by passing just one bit of extra information: whether x is to the left or right of the minimum. The Mish output plus this one bit is sufficient to recover x.

YashasSamaga commented 4 years ago
__device__ float fast_grad(float x)
{
    // gradient of mish(x) = x * tanh(softplus(x)):
    // with e = exp(x) and n = e^2 + 2e, tanh(softplus(x)) = n / (n + 2)
    auto e = __expf(x);
    auto n = e * e + 2 * e;

    // two algebraically equivalent forms of n / (n + 2), switched at x = -0.6f for accuracy
    float tsp;
    if (x <= -0.6f)
        tsp = __fdividef(n, n + 2);
    else
        tsp = 1 - __fdividef(2, n + 2);

    // d/dx softplus(x) = sigmoid(x) = e / (e + 1)
    const float grad_sp = __fdividef(e, e + 1);

    // chain rule: d/dx tanh(softplus(x)) = sech^2(softplus(x)) * sigmoid(x)
    const float grad_tsp = (1 - tsp*tsp) * grad_sp;
    // product rule: d/dx [x * tanh(softplus(x))]
    const float grad = x * grad_tsp + tsp;

    // beyond ~10.5 the gradient is 1 to within float precision
    return x > 10.5f ? 1 : grad;
}

| # | Time (float) | Time (float4) |
| --- | --- | --- |
| limit | 2.06ms | 2.04ms |
| mish_bwd_dn | 2.45ms | 2.34ms |
| mish_bwd_tb | 2.23ms | 2.12ms |
| mish_fast_grad | 2.11ms | 2.04ms |

[image: graphs_grad — comparison plots of the gradient implementations and their relative errors]

Fast grad is more accurate than the darknet implementation near the minimum (the relative error at the peak near -1 is lowest for fast grad).

Performing a normal division instead of __fdividef(n, n + 2) will make the fast grad implementation more accurate than darknet, but the non-vectorized version slows down a bit (the vectorized version is still at 2.04ms).

Code: https://gist.github.com/YashasSamaga/c3ee66732ff3c2b07cd48ea5bd7fb4e1