pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Potential runtime optimization of Mish activation #75251

Open Davmo049 opened 2 years ago

Davmo049 commented 2 years ago

🚀 The feature, motivation and pitch

Mish is a fairly popular activation function. It is defined as x times the tanh of a softplus, where softplus(x) = log(1 + exp(x)). Since tanh is itself defined in terms of exponentials, the logarithm and the exponentials can be made to cancel.

By algebraic manipulation one can show that an equivalent form is mish(x) = x * (1 - 2 / (1 + (1 + exp(x))^2)). The current implementation computes mish(x) = x * tanh(softplus(x)).
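To spell out the equivalence: with s = softplus(x) we have exp(s) = 1 + exp(x), so tanh(s) = (exp(2s) - 1) / (exp(2s) + 1) = 1 - 2 / (1 + exp(2s)) = 1 - 2 / (1 + (1 + exp(x))^2). A minimal Python check of the two forms (the helper names are illustrative, not existing PyTorch APIs):

```python
import torch

def mish_reference(x):
    # Reference form: x * tanh(softplus(x)), matching torch.nn.functional.mish
    return x * torch.tanh(torch.nn.functional.softplus(x))

def mish_rewritten(x):
    # Proposed algebraically equivalent form: x * (1 - 2 / (1 + (1 + exp(x))^2))
    return x * (1 - 2 / (1 + (1 + torch.exp(x)).square()))

x = torch.linspace(-20, 20, 10001, dtype=torch.float64)
print(torch.allclose(mish_reference(x), mish_rewritten(x)))  # should print True for this range in float64
```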

The backward implementation could be manipulated in a similar way, possibly yielding a speedup as well.
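Sketching the algebra for the backward pass (this is a derivation for discussion, not the existing kernel's formulation): differentiating x * (1 - 2 / (1 + (1 + exp(x))^2)) gives

mish'(x) = 1 - 2 / (1 + (1 + exp(x))^2) + 4 * x * exp(x) * (1 + exp(x)) / (1 + (1 + exp(x))^2)^2,

which reuses the same single exp. A quick cross-check against autograd (helper names are illustrative):

```python
import torch

def mish_rewritten_grad(x):
    # Hand-derived gradient of x * (1 - 2 / (1 + (1 + exp(x))^2)); for discussion only
    e = torch.exp(x)
    d = 1 + (1 + e).square()                      # shared denominator
    return 1 - 2 / d + 4 * x * e * (1 + e) / d.square()

x = torch.linspace(-20, 20, 10001, dtype=torch.float64, requires_grad=True)
(autograd_grad,) = torch.autograd.grad(torch.nn.functional.mish(x).sum(), x)
print(torch.allclose(autograd_grad, mish_rewritten_grad(x.detach())))  # should print True in float64
```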

Since this is purely an optimization, the existing tests should already cover the change.

The potential benefits of this method are:

1) Fewer "expensive" function invocations: one expensive call plus a couple of standard ops, versus two expensive invocations.
2) Less branching code: for numerical stability, softplus is currently approximated by a masked linear function for large inputs (see the sketch after this list).
3) Possibly better numerical accuracy: each "expensive" invocation tends to introduce some rounding error, so chaining two of them may lead to larger errors than one.
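For reference, the masking in point 2 corresponds to the documented numerical-stability threshold of torch.nn.functional.softplus (threshold defaults to 20). A rough Python-level sketch of that behavior, not the actual kernel code:

```python
import torch

def softplus_thresholded(x, threshold: float = 20.0):
    # Mirrors the documented fallback of torch.nn.functional.softplus (beta=1):
    # above `threshold` the result reverts to the identity (a masked linear function).
    return torch.where(x > threshold, x, torch.log1p(torch.exp(x)))

def mish_with_branch(x):
    # Reference composition; the torch.where above is the branch point 2 refers to.
    return x * torch.tanh(softplus_thresholded(x))
```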

The potential drawback is: 1) The implementation would no longer match the formula in the corresponding paper.

Suggestion: rewrite the following:

In aten/native/cuda/Activation.cu: mish_kernel and mish_backward_kernel

In caffe2/operators/mish_op.cc: MishFunctor and MishGradientOp

If you are interested, I could do the implementation.

Best, David

Alternatives

Not doing anything

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

vadimkantorov commented 2 years ago

My take: if it's not used in popular existing modules, it could be optimized directly in core. If it's already used and backward compatibility matters, an approximate-style string argument could be introduced, akin to GELU's: https://github.com/pytorch/pytorch/issues/39853#issuecomment-1075686097 (although the GELU approximation seems to have problems reported by Hugging Face)
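To make that concrete, a hypothetical Python-level wrapper could look like the sketch below. F.mish has no such argument today; the parameter name, accepted values, and dispatch are purely illustrative, modeled on F.gelu's approximate argument:

```python
import torch
import torch.nn.functional as F

def mish(x, approximate: str = "none"):
    # Hypothetical opt-in API sketch; not an existing PyTorch interface.
    if approximate == "none":
        # Today's behavior: x * tanh(softplus(x))
        return F.mish(x)
    if approximate == "rewritten":
        # Algebraically equivalent form proposed in this issue
        return x * (1 - 2 / (1 + (1 + torch.exp(x)).square()))
    raise ValueError(f"unsupported approximate={approximate!r}")
```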

cpuhrsch commented 2 years ago

In principle this seems like it should help, because it saves several expensive transcendental function calls. Out of curiosity I tried to implement this with the default fuser and nvfuser (via torch.jit.script), which work well for pointwise operations on GPUs.

I found that the numerical results frequently differ (as one might expect) under torch.isclose with the default tolerances. I also found that after fusion the performance, in milliseconds per iteration, is at best the same as the native implementation, but not necessarily better. Of course, the fuser might not behave optimally here (in particular, I found that pow(x, 2) is faster than square, which suggests the fusion may not be perfect), which is something I'm going to follow up on separately.
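For reference, a scripted version of the rewritten form that the fusers can consume might look like the following sketch (not necessarily the exact script used above; pow(x, 2) is used per the observation about square):

```python
import torch

@torch.jit.script
def mish_rewritten_scripted(x: torch.Tensor) -> torch.Tensor:
    # A chain of pointwise ops that the default fuser / nvfuser can fuse into a single kernel.
    d = 1.0 + torch.pow(1.0 + torch.exp(x), 2)
    return x * (1.0 - 2.0 / d)
```

Benchmarking it fairly requires CUDA inputs, a warm-up call so the script is compiled and fused, and torch.cuda.synchronize() (or torch.utils.benchmark) around the timed region.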

Currently we implement mish via

x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)))

It seems relatively straightforward to change this line and send a PR so that all the tests run against it, if you're up for it @Davmo049, but I think the numerical stability issues are actually quite severe, because you end up dividing by a potentially very large number.
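One way to quantify that concern is the probe below, which compares both forms in float32 against a float64 reference over a wide input range; the rewritten form can be expected to lose relative accuracy for very negative inputs, where 1 - 2 / (1 + (1 + exp(x))^2) cancels almost completely. The helper and the chosen range are illustrative, and no specific numbers are claimed here:

```python
import torch

def mish_rewritten(x):
    # Rewritten form under discussion; prone to cancellation for very negative x
    return x * (1 - 2 / (1 + (1 + torch.exp(x)).square()))

x64 = torch.linspace(-30, 30, 100001, dtype=torch.float64)
ref = torch.nn.functional.mish(x64)            # float64 reference values
x32 = x64.to(torch.float32)

for name, y in [("native", torch.nn.functional.mish(x32)),
                ("rewritten", mish_rewritten(x32))]:
    rel = ((y.to(torch.float64) - ref).abs() / ref.abs().clamp_min(1e-300)).max()
    print(f"{name}: max relative error {rel.item():.3e}")
```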