aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Add numba implementations for LogSumExp (reference implementation exists) #404

Open · twiecki opened this issue 3 years ago

twiecki commented 3 years ago

@aseyboldt provided an implementation we can probably use here as a starting point:

```python
import numba
import numpy as np
import theano
import theano.tensor as tt


@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp(p, out):
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2

    for i in numba.prange(n):
        res = 0
        for j in range(m):
            res += np.exp(p[i, j])
        res = np.log(res)
        out[i] = res

@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_grad(p, out, dout, dp):
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert len(dout) == n
    assert dout.ndim == 1
    assert dp.shape == p.shape

    for i in numba.prange(n):
        for j in range(m):
            dp[i, j] = np.exp(p[i, j] - out[i]) * dout[i]

class LogSumExp(theano.graph.op.Op):
    """Custom softmax, done through logsumexp"""

    itypes = [tt.dmatrix]
    otypes = [tt.dvector]

    def perform(self, node, inputs, outputs):
        x, = inputs
        n, m = x.shape
        out = np.zeros(n, dtype=x.dtype)
        numba_logsumexp(x, out)
        outputs[0][0] = out

    def grad(self, inputs, grads):
        x, = inputs
        dout, = grads
        logsumexp = self(x)
        return [LogSumExpGrad()(x, logsumexp, dout)]

class LogSumExpGrad(theano.graph.op.Op):
    """Joint operator"""

    itypes = [tt.dmatrix, tt.dvector, tt.dvector]
    otypes = [tt.dmatrix]

    def perform(self, node, inputs, outputs):
        p, out, dout = inputs
        dp = np.zeros(p.shape, dtype=p.dtype)
        numba_logsumexp_grad(p, out, dout, dp)
        outputs[0][0] = dp

logsumexp = LogSumExp()
```

twiecki commented 3 years ago

There's also a numerically stable one:

```python
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_stable(p, out):
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2

    for i in numba.prange(n):
        p_max = np.max(p[i])
        res = 0
        for j in range(m):
            res += np.exp(p[i, j] - p_max)
        res = np.log(res) + p_max
        out[i] = res
```
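
For reference, the stable kernel can be sanity-checked against plain NumPy along these lines (a minimal sketch; the test data and tolerance are illustrative):

```python
import numpy as np

# Illustrative check (not from the original comment): compare the stable
# kernel against a direct NumPy log-sum-exp along axis 1.
p = np.random.normal(size=(4, 7))
out = np.empty(p.shape[0])
numba_logsumexp_stable(p, out)

expected = np.log(np.sum(np.exp(p), axis=1))
assert np.allclose(out, expected)
```
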
brandonwillard commented 3 years ago

For anyone interested in implementing this, we first need to consider whether or not distinct Ops are required and, if they are, whether or not those new Ops should be implemented as scalar Ops that are broadcasted via Elemwise. This is especially important given that these implementations appear to put strong restrictions on the number of dimensions of their inputs.

For instance, these reference implementations manually specify gradient computations, but that shouldn't be necessary. The same may be true for the manual loops in the Op.perform implementations.

(NB: the parallelization features should be made configurable through the Numba backend as well, so that really shouldn't be a limiting factor.)

Let's start by doing comparisons with pure Aesara implementations of log-sum-exp (e.g. PyMC3's and/or pymc3-hmm's) that are compiled with the Numba backend.

In general, we want to make basic compositions of simple, pre-existing Ops as performant as possible, and only create custom Ops when it's absolutely necessary or dramatically more performant. In the latter case, such Ops would ideally only need to be used by optimizations that replace the less performant sub-graphs composed of pre-existing Ops (i.e. optimize automatically behind the scenes).
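
For example, a log-sum-exp written with pre-existing Ops gets its gradient from Aesara's automatic differentiation, with no hand-written gradient Op (a minimal sketch along those lines):

```python
import aesara
import aesara.tensor as at

X = at.matrix("X")
x_max = at.max(X, axis=1, keepdims=True)
lse = at.log(at.sum(at.exp(X - x_max), axis=1)) + x_max.flatten()

# The gradient graph is derived automatically by `aesara.grad`;
# no counterpart to the manual `LogSumExpGrad` Op is needed.
dX = aesara.grad(lse.sum(), X)
grad_fn = aesara.function([X], dX)
```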

fanshi118 commented 3 years ago

I tested the above; here are some preliminary results.

```python
import timeit

import numba
import numpy as np
import aesara
import aesara.tensor as at

test_data = np.random.normal(size=(3, 3))

# method 1: custom Op
X = at.matrix("X")
y = logsumexp(X)
y_fn = aesara.function([X], y)
_ = y_fn(test_data)

%timeit _ = y_fn(test_data)
## 26.5 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

def logsumexp2(x, axis=None, keepdims=True):
    # Adapted from https://github.com/Theano/Theano/issues/1563
    x_max = at.max(x, axis=axis, keepdims=True)
    x_max = at.switch(at.isinf(x_max), 0, x_max)
    res = at.log(at.sum(at.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return res if keepdims else res.squeeze()

# method 2: pure-Aesara graph, compiled with the Numba backend
y2 = logsumexp2(X, axis=1, keepdims=False)
y2_fn = aesara.function([X], y2, mode="NUMBA")
_ = y2_fn(test_data)

%timeit _ = y2_fn(test_data)
## 16 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

It seems that the Numba-compiled, Aesara-only graph outperforms the custom Op for this small test case.

ricardoV94 commented 3 years ago

Related: #465

brandonwillard commented 3 years ago

The performance difference is probably due to the use of prange in the custom Op. If we increase the size of test_data, I'm guessing we'll see the custom Op start to outperform the Aesara-only Op.

fanshi118 commented 3 years ago

> The performance difference is probably due to the use of prange in the custom Op. If we increase the size of test_data, I'm guessing we'll see the custom Op start to outperform the Aesara-only Op.

Seems like it.

All values are runtime per loop.

| array size | custom Numba Op | Numba-compiled Aesara graph | C/Python-compiled Aesara graph |
|---|---|---|---|
| 3 x 3 | 26.5 µs ± 1.66 µs (10000 loops) | 16 µs ± 126 ns (100000 loops) | 9.88 µs ± 55.1 ns (100000 loops) |
| 25 x 25 | 31.1 µs ± 1.67 µs (10000 loops) | 28.4 µs ± 3.47 (10000 loops) | 16.7 µs ± 469 ns (100000 loops) |
| 100 x 100 | 41.7 µs ± 1.83 µs (10000 loops) | 130 µs ± 1.32 µs (10000 loops) | 112 µs ± 4.07 µs (10000 loops) |
| 1000 x 1000 | 1.21 ms ± 120 µs (1000 loops) | 20.4 ms ± 212 µs (10 loops) | 11.4 ms ± 418 µs (100 loops) |
| 10000 x 10000 | 99.7 ms ± 7.99 (10 loops) | 4.47 s ± 83.9 ms (1 loop) | 1.43 s ± 12.9 ms (1 loop) |
brandonwillard commented 3 years ago

Well, this provides strong motivation for #409 (e.g. setting parallel=True in our invocations of numba.njit)! There might also be a few places where fastmath=True can be set.

Aside from that, let's try to determine how/why Numba takes almost twice as long for small arrays (e.g. is it jumping into object mode for some reason, are we adding unnecessary np.[as]array calls or making copies, etc.).

If anyone is interested, start by reviewing this Numba example and apply a similar approach to our Numba-compiled graphs.

NB: the first case might fall into the category of short-running functions with little room for significant improvement, but I kinda doubt that.
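
For a hand-written kernel, one quick way to approach the object-mode question above (a sketch with an illustrative kernel, not Aesara's generated code) is to compile it and inspect the inferred types:

```python
import numba
import numpy as np

@numba.njit
def rowwise_logsumexp(p, out):
    # Simple serial kernel used only to demonstrate the inspection calls.
    for i in range(p.shape[0]):
        p_max = np.max(p[i])
        res = 0.0
        for j in range(p.shape[1]):
            res += np.exp(p[i, j] - p_max)
        out[i] = np.log(res) + p_max

p = np.random.normal(size=(8, 8))
out = np.empty(8)
rowwise_logsumexp(p, out)

# `njit` would already have raised if Numba had fallen back to object mode;
# the compiled signatures should show concrete array types, not pyobject.
print(rowwise_logsumexp.signatures)
rowwise_logsumexp.inspect_types()
```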

brandonwillard commented 3 years ago

For anyone interested in running @fanshi118's comparison, here's a Gist that reproduces the results. Use timeit_data.to_markdown() to produce a Markdown table that can be embedded in an issue.

Here's what I get:

| data shape | custom Op | Aesara graph (Numba) | Aesara graph (C) |
|---|---|---|---|
| (3, 3) | 21 µs ± 1.29 µs (100000) | 16.4 µs ± 216 ns (10000) | 10.4 µs ± 114 ns (100000) |
| (25, 25) | 25.2 µs ± 637 ns (10000) | 27 µs ± 167 ns (10000) | 19.4 µs ± 845 ns (100000) |
| (100, 100) | 64.2 µs ± 3.25 µs (10000) | 203 µs ± 21.2 µs (1000) | 175 µs ± 10.5 µs (1000) |
| (1000, 1000) | 3.42 ms ± 168 µs (100) | 25.9 ms ± 1 ms (10) | 15.5 ms ± 419 µs (100) |
| (10000, 10000) | 320 ms ± 14.2 ms (1) | 4.82 s ± 324 ms (1) | 1.61 s ± 28.4 ms (1) |
brandonwillard commented 3 years ago

Also, for those who want to investigate the Numba-compiled function itself, here's how it can be obtained:

```python
import inspect

fn, *_ = aesara_numba_fn.maker.linker.make_all()
cl_vars = inspect.getclosurevars(fn)
thunk = cl_vars.nonlocals['thunks'][0]
thunk_signature = inspect.signature(thunk)
fgraph_jit = thunk_signature.parameters['fgraph_jit'].default
```

With fgraph_jit, one can perform standard low-level LLVM investigations:

```python
from aesara.link.numba.dispatch import get_numba_type

numba_input_sig = tuple(get_numba_type(i.type) for i in aesara_numba_fn.maker.fgraph.inputs)

fgraph_jit.compile(numba_input_sig)

fgraph_jit.inspect_types()
```
Output

```python
# (abridged) `fgraph_jit.inspect_types()` prints the Numba typed IR for the
# generated `numba_funcified_fgraph(X)`: each helper call produced by the
# backend (careduce_elemwise_maximum, dimshuffle, elemwise_subtract,
# elemwise_numba_funcified_fgraph*, careduce_elemwise_add, ...) is annotated
# with its inferred argument and return array types, e.g.
#     (array(float64, 2d, A),) -> array(float64, 1d, C)
```
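If the typed-IR listing isn't detailed enough, the same dispatcher also exposes the LLVM IR and generated assembly through standard Numba `Dispatcher` methods (a short aside, assuming `fgraph_jit` has been compiled as above):

```python
# Both mappings are keyed by compiled signature; the listings are long,
# so only a prefix is printed here.
for sig, llvm_ir in fgraph_jit.inspect_llvm().items():
    print(sig)
    print(llvm_ir[:2000])

for sig, asm in fgraph_jit.inspect_asm().items():
    print(sig)
    print(asm[:2000])
```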

Since functions like fgraph_jit are compiled FunctionGraphs, it's likely that one will need to inspect the functions called within the compiled function graph. Those can be obtained through similar means:

```python
# Take a look at the Python function that was compiled by Numba
print(inspect.getsource(fgraph_jit.py_func))
# def numba_funcified_fgraph(X):
#     auto_16690 = careduce_elemwise_maximum(X)
#     auto_16638 = dimshuffle(auto_16690)
#     auto_16777 = elemwise_numba_funcified_fgraph(auto_16638, auto_16670)
#     auto_16643 = elemwise_subtract(X, auto_16777)
#     auto_16658 = dimshuffle1(auto_16777)
#     auto_16689 = careduce_elemwise_maximum1(auto_16643)
#     auto_16673 = dimshuffle2(auto_16689)
#     auto_16789 = elemwise_numba_funcified_fgraph1(auto_16643, auto_16673)
#     auto_16681 = careduce_elemwise_add(auto_16789)
#     auto_16804 = elemwise_numba_funcified_fgraph2(auto_16689, auto_16681, auto_16658)
#     return (auto_16804,)

# Get the `careduce_elemwise_maximum` function called within `numba_funcified_fgraph`
careduce_elemwise_maximum = inspect.getclosurevars(fgraph_jit.py_func).globals[
    "careduce_elemwise_maximum"
]

print(inspect.getsource(careduce_elemwise_maximum.py_func))
# def careduce_elemwise_maximum(X):
#     axis_0_res = careduce_axes_fns[0](X)
#     return axis_0_res
```

fanshi118 commented 3 years ago

Tried turning parallel and fastmath off separately and got the following results:

| data shape | custom Op | custom Op (fastmath only) | custom Op (parallel only) | Aesara graph (Numba) | Aesara graph (C) |
|---|---|---|---|---|---|
| (3, 3) | 25.1 µs ± 640 ns (10000) | 9.96 µs ± 131 ns (100000) | 34.5 µs ± 19.5 µs (1) | 15.1 µs ± 115 ns (10000) | 9.71 µs ± 53.1 ns (100000) |
| (25, 25) | 27.7 µs ± 712 ns (10000) | 14.2 µs ± 237 ns (10000) | 27.2 µs ± 321 ns (10000) | 24.1 µs ± 504 ns (10000) | 15.8 µs ± 138 ns (100000) |
| (100, 100) | 38.8 µs ± 386 ns (10000) | 67.5 µs ± 971 ns (10000) | 40.6 µs ± 2.2 µs (10000) | 128 µs ± 1.31 µs (10000) | 107 µs ± 1.8 µs (10000) |
| (1000, 1000) | 1.09 ms ± 21.2 µs (1000) | 5.73 ms ± 176 µs (100) | 1.35 ms ± 113 µs (1000) | 20.1 ms ± 660 µs (10) | 10.3 ms ± 283 µs (100) |
| (10000, 10000) | 114 ms ± 12.7 ms (10) | 632 ms ± 74.6 ms (1) | 133 ms ± 16.6 ms (10) | 4.32 s ± 19.7 ms (1) | 1.36 s ± 7.81 ms (1) |
twiecki commented 3 years ago

@brandonwillard Do you think inlining will get the numbers closer?

brandonwillard commented 3 years ago

> @brandonwillard Do you think inlining will get the numbers closer?

I've already inlined most of the unnecessary function calls created by the backend (at least since #476). Given the way that the latency scales with shape, I'm guessing there are still a few more array copies being made somewhere (e.g. something isn't being in-placed).
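
One way to look for that (a sketch using Aesara's standard graph printer; `y2_fn` is the Numba-compiled function from the benchmark above) is to dump the optimized graph and check which Ops ended up with in-place variants and whether any stray copy/`Alloc` nodes remain:

```python
from aesara.printing import debugprint

# Prints the rewritten graph behind the compiled function; in-place
# replacements and leftover allocation/copy nodes are visible here.
debugprint(y2_fn)
```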

fanshi118 commented 3 years ago

Here's a Gist for running the LSE function in three separate steps (by isolating at.max and at.switch). The comparison between Numba and C compilation looks like the following:

| data shape | Aesara graph (Numba) - max | Aesara graph (C) - max | Aesara graph (Numba) - switch | Aesara graph (C) - switch | Aesara graph (Numba) - LSE | Aesara graph (C) - LSE |
|---|---|---|---|---|---|---|
| (3, 3) | 11.6 µs ± 142 ns (100000) | 7.13 µs ± 84.9 ns (100000) | 10.2 µs ± 27.8 ns (100000) | 6.77 µs ± 61.8 ns (100000) | 16.7 µs ± 252 ns (100000) | 11.6 µs ± 107 ns (100000) |
| (25, 25) | 12.7 µs ± 101 ns (100000) | 7.72 µs ± 76.8 ns (100000) | 10.3 µs ± 113 ns (100000) | 6.67 µs ± 72.4 ns (100000) | 25.9 µs ± 432 ns (10000) | 17.5 µs ± 103 ns (100000) |
| (100, 100) | 24.3 µs ± 499 ns (10000) | 18.2 µs ± 157 ns (100000) | 10.5 µs ± 146 ns (100000) | 6.79 µs ± 80.8 ns (100000) | 144 µs ± 799 ns (10000) | 104 µs ± 4.21 µs (10000) |
| (1000, 1000) | 1.82 ms ± 20.9 µs (1000) | 862 µs ± 11.1 µs (1000) | 12.7 µs ± 240 ns (100000) | 7.85 µs ± 309 ns (100000) | 13.6 ms ± 89.4 µs (100) | 8.52 ms ± 38.7 µs (100) |
| (10000, 10000) | 925 ms ± 4.62 ms (1) | 113 ms ± 494 µs (10) | 35.1 µs ± 373 ns (10000) | 14.1 µs ± 75.5 ns (100000) | 3.39 s ± 14.5 ms (1) | 1.3 s ± 5.65 ms (1) |
twiecki commented 3 years ago

@fanshi118 Thanks, those are fascinating. I was hoping Numba would be faster. Is this with fastmath and parallelization?

brandonwillard commented 3 years ago

There appears to be some latency introduced by Aesara's Function interface. For example, here's a simple Numba-only implementation of at.switch(at.isinf(x), 0, x) (i.e. one of the subgraphs in the log-sum-exp graph):

```python
import numba
import numpy as np


@numba.vectorize
def numba_switch_isinf(x):
    if np.isinf(x):
        return 0
    else:
        return x
```
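
Called directly on an array it matches the behaviour of the Aesara subgraph (a small illustrative check):

```python
import numpy as np

x = np.array([[1.0, np.inf], [-np.inf, 2.0]])
print(numba_switch_isinf(x))
# -> [[1. 0.]
#     [0. 2.]]
```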

If we compare this with the Aesara C-compiled function, the Aesara Numba-compiled function, and the Aesara-generated Numba function called directly (i.e. without Function), we get the following:

| data shape | Numba | Aesara-C | Aesara-Numba | Aesara-Numba (direct) |
|---|---|---|---|---|
| (3, 3) | 328 ns ± 1.72 ns (1000000) | 6.42 µs ± 76.1 ns (100000) | 9.97 µs ± 127 ns (100000) | 508 ns ± 2.92 ns (1000000) |
| (100, 100) | 3.06 µs ± 21.1 ns (100000) | 18.7 µs ± 134 ns (100000) | 16.6 µs ± 251 ns (100000) | 3.75 µs ± 50 ns (100000) |
| (1000, 1000) | 784 µs ± 39.6 µs (1000) | 1.16 ms ± 20.1 µs (1000) | 823 µs ± 12.8 µs (1000) | 789 µs ± 4.18 µs (1000) |
| (5000, 5000) | 34.3 ms ± 610 µs (10) | 46.3 ms ± 1.2 ms (10) | 89.3 ms ± 3.4 ms (10) | 97 ms ± 3.12 ms (10) |
| (10000, 10000) | 214 ms ± 11 ms (1) | 268 ms ± 7.68 ms (1) | 334 ms ± 1.73 ms (1) | 329 ms ± 2.14 ms (1) |

We might need to adjust our timeit approach, because I'm not sure how trustworthy the larger shape statistics are, but—as previously stated—it looks like the smaller shape results are demonstrating a Function-induced call overhead in the Aesara-C and Aesara-Numba results.

Regardless, someone should take a closer look at the Aesara-generated Numba function and see if there's anything that could be causing a drop in performance for larger shapes.
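
For the larger shapes, something along these lines would make the single-loop measurements more comparable (plain `timeit` with a fixed repeat/number instead of `%timeit`'s auto-ranging; `y2_fn` and the data size are just placeholders):

```python
import timeit

import numpy as np

test_data = np.random.normal(size=(10000, 10000))

# Fix the number of calls per measurement so even the slow cases are
# averaged over several repeats rather than a single run.
times = timeit.repeat(lambda: y2_fn(test_data), repeat=7, number=5)
print(min(times) / 5, "seconds per call (best of 7 repeats)")
```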

brandonwillard commented 3 years ago

It's worth noting that we really shouldn't focus on the fast_math and/or parallel options; those are simple things to add, and they don't address the persistent underlying sources of latency (i.e. the ones that will still be present when those options are enabled).

Instead, we need to focus on how well the Aesara-generated Numba functions perform compared to Aesara's C results and equivalent Numba-only functions.

In other words, we need to make Aesara's Numba-compiled functions at least as fast as the current C backend and—ideally—on par with a handwritten Numba function.

brandonwillard commented 3 years ago

Here are the results of the log-sum-exp comparison after #498 (parallel and fastmath are off in all cases):

| data shape | Numba | Aesara-C | Aesara-Numba | Aesara-Numba (direct) |
|---|---|---|---|---|
| (3, 3) | 635 ns ± 7.91 ns (1000000) | 11.8 µs ± 269 ns (100000) | 16.7 µs ± 459 ns (100000) | 3.53 µs ± 45.3 ns (100000) |
| (25, 25) | 6.87 µs ± 55 ns (100000) | 22.1 µs ± 518 ns (10000) | 32.2 µs ± 529 ns (10000) | 17.8 µs ± 248 ns (100000) |
| (100, 100) | 95.4 µs ± 1.34 µs (10000) | 160 µs ± 3.81 µs (10000) | 208 µs ± 4.03 µs (1000) | 202 µs ± 2.62 µs (10000) |
| (1000, 1000) | 8.62 ms ± 51.8 µs (100) | 14.8 ms ± 363 µs (100) | 20.1 ms ± 219 µs (100) | 20.4 ms ± 700 µs (10) |
| (10000, 10000) | 826 ms ± 11.4 ms (1) | 1.48 s ± 18.1 ms (1) | 4.04 s ± 43.9 ms (1) | 4.02 s ± 53.1 ms (1) |
twiecki commented 3 years ago

Getting there. Where do you think the other bottlenecks are?

brandonwillard commented 3 years ago

We need to continue scrutinizing the generated Numba functions piece by piece.

My guess is that there's still an in-place operation missing and/or an erroneous array copy taking place within the generated functions.

brandonwillard commented 3 years ago

@kc611 and I were just going over this issue (via #529) and, while developing a clearer MWE (see here), we noticed that the axis argument—and the resulting transpose and loop in careduce_axis—is likely the main source of decreased Numba performance compared to NumPy.

twiecki commented 3 years ago

@brandonwillard Does that mean we can keep using vectorize?

brandonwillard commented 3 years ago

> @brandonwillard Does that mean we can keep using vectorize?

Possibly, but all the things that @stuartarchibald and others mentioned here still apply; however, judging from our code generation, it seems like we could at least do a better job of keeping arrays/operations "contiguous-friendly".
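
To see why the reduction axis (and the transposes it implies) matters, here is a small standalone comparison, independent of Aesara's generated careduce code: reducing over the contiguous (last) axis streams through memory, while reducing over the first axis strides across rows in the inner loop.

```python
import numba
import numpy as np

@numba.njit
def max_over_last_axis(x):
    # Inner loop walks contiguous memory (row by row).
    out = np.empty(x.shape[0])
    for i in range(x.shape[0]):
        m = x[i, 0]
        for j in range(1, x.shape[1]):
            if x[i, j] > m:
                m = x[i, j]
        out[i] = m
    return out

@numba.njit
def max_over_first_axis(x):
    # Inner loop strides across rows of a C-contiguous array.
    out = np.empty(x.shape[1])
    for j in range(x.shape[1]):
        m = x[0, j]
        for i in range(1, x.shape[0]):
            if x[i, j] > m:
                m = x[i, j]
        out[j] = m
    return out
```

On a C-contiguous input the first version is typically much faster; the "contiguous-friendly" fix for the second case is to swap the loop order (accumulating into `out` row by row) or to transpose/copy once up front, rather than striding in the innermost loop.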

46319943 commented 1 year ago

> There's also a numerically stable one:

```python
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_stable(p, out):
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2

    for i in numba.prange(n):
        p_max = np.max(p[i])
        res = 0
        for j in range(m):
            res += np.exp(p[i, j] - p_max)
        res = np.log(res) + p_max
        out[i] = res
```

For anyone who wants to use this function directly: you have to decide whether parallel should be set to True for your particular case.

As can be seen from @brandonwillard's results (the run with parallel and the run without parallel), when the array is relatively small, parallel actually hurts performance.

I used the function above directly and found that setting parallel=True actually slowed down the program, since my input array was small, with shape (1, 10).
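
One pragmatic workaround (a sketch, not something proposed in this thread: the helper names and the threshold are made up) is to compile both a serial and a parallel variant of the kernel and dispatch on input size:

```python
import numba
import numpy as np

def _logsumexp_kernel(p, out):
    n, m = p.shape
    for i in numba.prange(n):  # acts like plain range without parallel=True
        p_max = np.max(p[i])
        res = 0.0
        for j in range(m):
            res += np.exp(p[i, j] - p_max)
        out[i] = np.log(res) + p_max

# Compile the same kernel twice from the same Python source.
_logsumexp_serial = numba.njit(fastmath=True)(_logsumexp_kernel)
_logsumexp_parallel = numba.njit(parallel=True, fastmath=True)(_logsumexp_kernel)

def logsumexp_stable(p, out, parallel_threshold=100_000):
    # Below the (tunable) threshold the threading overhead tends to dominate,
    # so fall back to the serial build.
    if p.size < parallel_threshold:
        _logsumexp_serial(p, out)
    else:
        _logsumexp_parallel(p, out)
```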