Open Davmo049 opened 2 years ago
My take: if it's not used in popular existing modules, it could be optimized directly in core. If it's already used and compatibility is important, then an approx string argument could be introduced, akin to GELU https://github.com/pytorch/pytorch/issues/39853#issuecomment-1075686097 (although the GELU approximation seems to have problems reported by HF).
In principle this seems like it should help, because it saves a lot of expensive transcendental function calls. Out of curiosity I tried to implement this with the default fuser and with nvfuser (via torch.jit.script), which work well for pointwise operations on GPUs.
I found that the numerical results are (as one might expect) frequently different according to torch.isclose (e.g. at the default tolerances). I also found that after fusion the performance, in milliseconds per iteration, is at best the same as the native implementation, but not necessarily better. Of course, the fuser might not behave perfectly here (in particular I found that pow(x, 2) is faster than square, which suggests the fusion might not be optimal), which is something I'm going to follow up on separately.
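For reference, here is a minimal sketch of the kind of comparison described above (not the exact script that was used; the shapes, tolerances, and timing loop are illustrative assumptions, and it needs a CUDA device):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def mish_rewritten(x):
    # Algebraically equivalent form: x * (1 - 2 / (1 + (1 + exp(x))^2))
    return x * (1 - 2 / (1 + (1 + torch.exp(x)) ** 2))

x = torch.randn(1 << 20, device="cuda")
ref = F.mish(x)
out = mish_rewritten(x)

# Fraction of elements that agree at the default torch.isclose tolerances.
print(torch.isclose(out, ref).float().mean().item())

# Rough timing; a real benchmark should use torch.utils.benchmark instead.
for name, fn in [("native", F.mish), ("scripted", mish_rewritten)]:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(100):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    print(name, start.elapsed_time(end) / 100, "ms/iter")
```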
Currently we implement mish via

```cpp
x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)))
```
It seems relatively straightforward to change this line and send a PR to run all the tests, if you're up for it @Davmo049. But I think the numerical stability issues are actually quite severe, because you're dividing by a potentially very large (and, for large inputs, overflowing) denominator.
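To make the stability concern concrete, a quick check along these lines could be run (illustrative only; the exact breakdown points depend on dtype). In float32 the intermediate (1 + exp(x))**2 overflows around x ≈ 44 and exp(x) itself overflows around x ≈ 89; the forward value of the naive rewritten form may still come out as roughly x, but gradients taken through the overflowing intermediates can degrade to NaN:

```python
import torch
import torch.nn.functional as F

def mish_rewritten(x):
    # Proposed closed form: x * (1 - 2 / (1 + (1 + exp(x))^2))
    return x * (1 - 2 / (1 + (1 + torch.exp(x)) ** 2))

x = torch.tensor([-20.0, 0.0, 44.0, 90.0], dtype=torch.float32, requires_grad=True)

ref = F.mish(x)
out = mish_rewritten(x)
print(ref.detach())
print(out.detach())  # forward values can still look reasonable despite inf intermediates

# Gradients of the naive composition, compared with the native implementation.
print(torch.autograd.grad(out.sum(), x))
print(torch.autograd.grad(ref.sum(), x))
```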
🚀 The feature, motivation and pitch
Mish is a quite popular activation function. It is defined as x times the tanh of a softplus: mish(x) = x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x)). Since tanh is itself defined in terms of exponentials, the logarithm and the exponentials cancel out.
By algebraic manipulation one can show that an equivalent function is mish(x) = x * (1 - 2 / (1 + (1 + exp(x))^2)). The current implementation computes mish(x) = x * tanh(softplus(x)) directly.
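For completeness (this derivation is spelled out here rather than quoted from anywhere), the manipulation is just the identity tanh(y) = 1 - 2/(e^(2y) + 1) evaluated at y = softplus(x), using e^softplus(x) = 1 + e^x:

```math
\tanh(y) = \frac{e^{y} - e^{-y}}{e^{y} + e^{-y}} = 1 - \frac{2}{e^{2y} + 1}
\;\;\Rightarrow\;\;
\operatorname{mish}(x) = x\,\tanh\bigl(\operatorname{softplus}(x)\bigr)
= x\left(1 - \frac{2}{1 + \bigl(1 + e^{x}\bigr)^{2}}\right)
```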
The backward implementation could be manipulated in a similar way, which might also yield a speedup there.
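As a sanity check of the algebra in both directions, here is a small sketch (my own derivation and naming, not an existing implementation): differentiating the closed form gives mish'(x) = 1 - 2/d + 4·x·e·(1+e)/d², with e = exp(x) and d = 1 + (1+e)², and the snippet below compares both the forward value and this gradient against F.mish and autograd in double precision:

```python
import torch
import torch.nn.functional as F

def mish_fwd(x):
    # Closed form: x * (1 - 2 / d), where d = 1 + (1 + exp(x))^2
    d = 1 + (1 + torch.exp(x)) ** 2
    return x * (1 - 2 / d)

def mish_bwd(x, grad_out):
    # Derivative of the closed form:
    #   mish'(x) = 1 - 2/d + 4*x*e*(1 + e) / d^2, with e = exp(x), d = 1 + (1 + e)^2
    e = torch.exp(x)
    d = 1 + (1 + e) ** 2
    return grad_out * (1 - 2 / d + 4 * x * e * (1 + e) / d ** 2)

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
ref = F.mish(x)
(ref_grad,) = torch.autograd.grad(ref.sum(), x)

print(torch.allclose(mish_fwd(x.detach()), ref.detach()))                  # expect True
print(torch.allclose(mish_bwd(x.detach(), torch.ones_like(x)), ref_grad))  # expect True
```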
Since this is only an optimization, the affected code should already be covered by tests.
The potential benefits of this method are:
1) Fewer "expensive" function invocations (one expensive call plus a couple of standard ops, versus two expensive invocations).
2) Less branching code. Due to numerical instability, softplus is approximated as a masked linear function for large inputs.
3) Possibly better numerical accuracy: each "expensive" invocation tends to introduce some rounding error, so chaining two of them can lead to larger errors.
The potential drawbacks are: 1) The implementation would no longer mirror the formula in the corresponding paper.
Suggestion: rewrite
In aten/native/cuda/Activation.cu: mish_kernel, mish_backward_kernel
In caffe2/operators/mish_op.cc: MishFunctor, MishGradientOp
If you are interested, I could do the implementation.
Best, David
Alternatives
Not doing anything
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345