twiecki opened 3 years ago
There's also a numerically stable one:
```python
import numba
import numpy as np


@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_stable(p, out):
    # Row-wise log-sum-exp of `p`, written into the preallocated vector `out`
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2
    for i in numba.prange(n):
        # subtract the row maximum before exponentiating for numerical stability
        p_max = np.max(p[i])
        res = 0.0
        for j in range(m):
            res += np.exp(p[i, j] - p_max)
        res = np.log(res) + p_max
        out[i] = res
```
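For anyone trying it out, a minimal usage sketch (the shape below is arbitrary): the caller preallocates the output vector and passes it in.

```python
p = np.random.normal(size=(1000, 50))
out = np.empty(p.shape[0])
numba_logsumexp_stable(p, out)  # out[i] now holds logsumexp(p[i])
```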
For anyone interested in implementing this, we first need to consider whether or not distinct `Op`s are required and, if they are, whether or not those new `Op`s should be implemented as scalar `Op`s that are broadcasted via `Elemwise`. This is especially important given that these implementations appear to put strong restrictions on the number of dimensions of their inputs.

For instance, these reference implementations manually specify gradient computations, but that shouldn't be necessary. The same may be true for the manual loops in the `Op.perform` implementations.

(NB: the parallelization features should be [made] configurable through the Numba backend, as well, so that really shouldn't be a limiting factor.)
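Regarding the point above about gradients, here's a minimal sketch (not from the thread) of why a log-sum-exp built from existing `Op`s needs no hand-written gradient code: Aesara derives it symbolically from the composed graph.

```python
import aesara
import aesara.tensor as at

X = at.matrix("X")
x_max = at.max(X, axis=1, keepdims=True)
lse = at.log(at.sum(at.exp(X - x_max), axis=1)) + x_max.squeeze()

# The gradient is derived automatically; no manual grad() implementation needed
lse_grad_fn = aesara.function([X], aesara.grad(lse.sum(), X))
```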
Let's start by doing comparisons with pure Aesara implementations of log-sum-exp (e.g. PyMC3's and/or `pymc3-hmm`'s) that are compiled with the Numba backend.

In general, we want to make basic compositions of simple, pre-existing `Op`s as performant as possible, and only create custom `Op`s when it's absolutely necessary or dramatically more performant. In the latter case, such `Op`s would ideally only be used by optimizations that replace the less performant sub-graphs composed of pre-existing `Op`s (i.e. optimize automatically behind the scenes).
I tested the above; here are some preliminary results.
```python
import timeit

import numba
import numpy as np

import aesara
import aesara.tensor as at

test_data = np.random.normal(size=(3, 3))

# method 1: custom Op
X = at.matrix("X")
y = logsumexp(X)
y_fn = aesara.function([X], y)

_ = y_fn(test_data)
%timeit _ = y_fn(test_data)
## 26.5 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


def logsumexp2(x, axis=None, keepdims=True):
    # Adapted from https://github.com/Theano/Theano/issues/1563
    x_max = at.max(x, axis=axis, keepdims=True)
    x_max = at.switch(at.isinf(x_max), 0, x_max)
    res = at.log(at.sum(at.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return res if keepdims else res.squeeze()


# method 2: regular Op (numba-compiled)
y2 = logsumexp2(X, axis=1, keepdims=False)
y2_fn = aesara.function([X], y2, mode="NUMBA")

_ = y2_fn(test_data)
%timeit _ = y2_fn(test_data)
## 16 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
It seems that the Numba-compiled Aesara-only `Op` outperforms the custom `Op` in this particular test case.
Related: #465
The performance difference is probably due to the use of `prange` in the custom `Op`. If we increase the size of `test_data`, I'm guessing we'll see the custom `Op` start to outperform the Aesara-only `Op`.
> The performance difference is probably due to the use of `prange` in the custom `Op`. If we increase the size of `test_data`, I'm guessing we'll see the custom `Op` start to outperform the Aesara-only `Op`.
Seems like it.
array size | custom Numba `Op` (runtime per loop) | Numba-compiled Aesara graph | C/Python-compiled Aesara graph |
---|---|---|---|
3 x 3 | 26.5 µs ± 1.66 µs (10000 loops) | 16 µs ± 126 ns (100000 loops) | 9.88 µs ± 55.1 ns (100000 loops) |
25 x 25 | 31.1 µs ± 1.67 µs (10000 loops) | 28.4 µs ± 3.47 µs (10000 loops) | 16.7 µs ± 469 ns (100000 loops) |
100 x 100 | 41.7 µs ± 1.83 µs (10000 loops) | 130 µs ± 1.32 µs (10000 loops) | 112 µs ± 4.07 µs (10000 loops) |
1000 x 1000 | 1.21 ms ± 120 µs (1000 loops) | 20.4 ms ± 212 µs (10 loops) | 11.4 ms ± 418 µs (100 loops) |
10000 x 10000 | 99.7 ms ± 7.99 ms (10 loops) | 4.47 s ± 83.9 ms (1 loop) | 1.43 s ± 12.9 ms (1 loop) |
Well, this provides strong motivation for #409 (e.g. setting `parallel=True` in our invocations of `numba.njit`)! There might also be a few places where `fastmath=True` can be set.

Aside from that, let's try to determine how/why Numba takes almost twice as long for small arrays (e.g. is it jumping into object mode for some reason, are we adding unnecessary `np.[as]array` calls or making copies, etc.).

If anyone is interested, start by reviewing this Numba example and apply a similar approach to our Numba-compiled graphs.

NB: the first case might fall into the category of short-running functions with little room for significant improvement, but I kinda doubt that.
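As a starting point for the object-mode question, here's a minimal sketch (not from the thread) of how one can check whether a Numba dispatcher silently fell back to object mode; the function `f` is just a stand-in:

```python
import numba
import numpy as np


@numba.jit  # without nopython=True, Numba may silently fall back to object mode
def f(x):
    return np.exp(x).sum()


f(np.ones((3, 3)))

# An empty list here means every compiled signature ended up in object mode
print(f.nopython_signatures)

# Per-line type annotations; `pyobject` types also indicate object mode
f.inspect_types()
```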
For anyone interested in running @fanshi118's comparison, here's a Gist that reproduces the results. Use `timeit_data.to_markdown()` to produce a Markdown table that can be embedded in an issue.
Here's what I get:
data shape | custom `Op` | Aesara graph (Numba) | Aesara graph (C) |
---|---|---|---|
(3, 3) | 21 µs ± 1.29 µs (100000) | 16.4 µs ± 216 ns (10000) | 10.4 µs ± 114 ns (100000) |
(25, 25) | 25.2 µs ± 637 ns (10000) | 27 µs ± 167 ns (10000) | 19.4 µs ± 845 ns (100000) |
(100, 100) | 64.2 µs ± 3.25 µs (10000) | 203 µs ± 21.2 µs (1000) | 175 µs ± 10.5 µs (1000) |
(1000, 1000) | 3.42 ms ± 168 µs (100) | 25.9 ms ± 1 ms (10) | 15.5 ms ± 419 µs (100) |
(10000, 10000) | 320 ms ± 14.2 ms (1) | 4.82 s ± 324 ms (1) | 1.61 s ± 28.4 ms (1) |
Also, for those who want to investigate the Numba-compiled function itself, here's how it can be obtained:
```python
import inspect

fn, *_ = aesara_numba_fn.maker.linker.make_all()
cl_vars = inspect.getclosurevars(fn)
thunk = cl_vars.nonlocals["thunks"][0]
thunk_signature = inspect.signature(thunk)
fgraph_jit = thunk_signature.parameters["fgraph_jit"].default
```
With `fgraph_jit`, one can perform standard low-level LLVM investigations:
```python
from aesara.link.numba.dispatch import get_numba_type

numba_input_sig = tuple(get_numba_type(i.type) for i in aesara_numba_fn.maker.fgraph.inputs)
fgraph_jit.compile(numba_input_sig)
fgraph_jit.inspect_types()
```
```python
numba_funcified_fgraph (array(float64, 2d, A),)
--------------------------------------------------------------------------------
# File: /tmp/user/1000/tmpp3l2be04
# --- LINE 2 ---

def numba_funcified_fgraph(X):

    # --- LINE 3 ---
    #   label 0
    #   X = arg(0, name=X)  :: array(float64, 2d, A)
    #   $2load_global.0 = global(careduce_elemwise_maximum: CPUDispatcher(
```
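Beyond `inspect_types`, the other standard Numba introspection hooks are available on the dispatcher as well; a short sketch, assuming the same `fgraph_jit` and `numba_input_sig` from above:

```python
# LLVM IR and native assembly for the compiled signature
llvm_ir = fgraph_jit.inspect_llvm(numba_input_sig)
asm = fgraph_jit.inspect_asm(numba_input_sig)
print(llvm_ir[:2000])
```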
Since functions like `fgraph_jit` are compiled `FunctionGraph`s, it's likely that one will need to inspect the functions called within the compiled function graph. Those can be obtained through similar means:
```python
# Take a look at the Python function that was compiled by Numba
print(inspect.getsource(fgraph_jit.py_func))
# def numba_funcified_fgraph(X):
#     auto_16690 = careduce_elemwise_maximum(X)
#     auto_16638 = dimshuffle(auto_16690)
#     auto_16777 = elemwise_numba_funcified_fgraph(auto_16638, auto_16670)
#     auto_16643 = elemwise_subtract(X, auto_16777)
#     auto_16658 = dimshuffle1(auto_16777)
#     auto_16689 = careduce_elemwise_maximum1(auto_16643)
#     auto_16673 = dimshuffle2(auto_16689)
#     auto_16789 = elemwise_numba_funcified_fgraph1(auto_16643, auto_16673)
#     auto_16681 = careduce_elemwise_add(auto_16789)
#     auto_16804 = elemwise_numba_funcified_fgraph2(auto_16689, auto_16681, auto_16658)
#     return (auto_16804,)

# Get the `careduce_elemwise_maximum` function called within `numba_funcified_fgraph`
careduce_elemwise_maximum = inspect.getclosurevars(fgraph_jit.py_func).globals[
    "careduce_elemwise_maximum"
]
print(inspect.getsource(careduce_elemwise_maximum.py_func))
# def careduce_elemwise_maximum(X):
#     axis_0_res = careduce_axes_fns[0](X)
#     return axis_0_res
```
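The same pattern should drill one level deeper, into the per-axis reduction functions; a hedged sketch (whether `careduce_axes_fns` shows up under `.globals` or `.nonlocals` may depend on how the source was generated):

```python
cv = inspect.getclosurevars(careduce_elemwise_maximum.py_func)
env = {**cv.globals, **cv.nonlocals}
careduce_axes_fns = env["careduce_axes_fns"]
print(inspect.getsource(careduce_axes_fns[0].py_func))
```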
I tried turning `parallel` and `fastmath` off separately and got the following results:
data shape | custom `Op` | custom `Op` (fastmath only) | custom `Op` (parallel only) | Aesara graph (Numba) | Aesara graph (C) |
---|---|---|---|---|---|
(3, 3) | 25.1 µs ± 640 ns (10000) | 9.96 µs ± 131 ns (100000) | 34.5 µs ± 19.5 µs (1) | 15.1 µs ± 115 ns (10000) | 9.71 µs ± 53.1 ns (100000) |
(25, 25) | 27.7 µs ± 712 ns (10000) | 14.2 µs ± 237 ns (10000) | 27.2 µs ± 321 ns (10000) | 24.1 µs ± 504 ns (10000) | 15.8 µs ± 138 ns (100000) |
(100, 100) | 38.8 µs ± 386 ns (10000) | 67.5 µs ± 971 ns (10000) | 40.6 µs ± 2.2 µs (10000) | 128 µs ± 1.31 µs (10000) | 107 µs ± 1.8 µs (10000) |
(1000, 1000) | 1.09 ms ± 21.2 µs (1000) | 5.73 ms ± 176 µs (100) | 1.35 ms ± 113 µs (1000) | 20.1 ms ± 660 µs (10) | 10.3 ms ± 283 µs (100) |
(10000, 10000) | 114 ms ± 12.7 ms (10) | 632 ms ± 74.6 ms (1) | 133 ms ± 16.6 ms (10) | 4.32 s ± 19.7 ms (1) | 1.36 s ± 7.81 ms (1) |
@brandonwillard Do you think inlining will get the numbers closer?
> @brandonwillard Do you think inlining will get the numbers closer?
I've already inlined most of the unnecessary function calls created by the backend (at least since #476). Given the way that the latency scales with shape, I'm guessing there are still a few more array copies being made somewhere (e.g. something isn't being in-placed).
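One way to look for missing in-placement is to inspect the optimized graph directly; a sketch, assuming `y2_fn` is the Numba-compiled Aesara function from earlier:

```python
import aesara

# Print the optimized graph that `y2_fn` executes; in-place (destructive) Ops
# are marked in the output, so a missing in-place rewrite shows up here.
aesara.dprint(y2_fn)
```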
Here's a Gist for running the LSE function in three separate steps (by isolating `at.max` and `at.switch`). The comparison between Numba- and C-compilation looks like the following:
data shape | Aesara graph (Numba) - max | Aesara graph (C) - max | Aesara graph (Numba) - switch | Aesara graph (C) - switch | Aesara graph (Numba) - LSE | Aesara graph (C) - LSE |
---|---|---|---|---|---|---|
(3, 3) | 11.6 µs ± 142 ns (100000) | 7.13 µs ± 84.9 ns (100000) | 10.2 µs ± 27.8 ns (100000) | 6.77 µs ± 61.8 ns (100000) | 16.7 µs ± 252 ns (100000) | 11.6 µs ± 107 ns (100000) |
(25, 25) | 12.7 µs ± 101 ns (100000) | 7.72 µs ± 76.8 ns (100000) | 10.3 µs ± 113 ns (100000) | 6.67 µs ± 72.4 ns (100000) | 25.9 µs ± 432 ns (10000) | 17.5 µs ± 103 ns (100000) |
(100, 100) | 24.3 µs ± 499 ns (10000) | 18.2 µs ± 157 ns (100000) | 10.5 µs ± 146 ns (100000) | 6.79 µs ± 80.8 ns (100000) | 144 µs ± 799 ns (10000) | 104 µs ± 4.21 µs (10000) |
(1000, 1000) | 1.82 ms ± 20.9 µs (1000) | 862 µs ± 11.1 µs (1000) | 12.7 µs ± 240 ns (100000) | 7.85 µs ± 309 ns (100000) | 13.6 ms ± 89.4 µs (100) | 8.52 ms ± 38.7 µs (100) |
(10000, 10000) | 925 ms ± 4.62 ms (1) | 113 ms ± 494 µs (10) | 35.1 µs ± 373 ns (10000) | 14.1 µs ± 75.5 ns (100000) | 3.39 s ± 14.5 ms (1) | 1.3 s ± 5.65 ms (1) |
@fanshi118 Thanks, those are fascinating. I was hoping Numba would be faster. Is this with `fastmath` and parallelization?
There appears to be some latency introduced by Aesara's `Function` interface. For example, here's a simple Numba-only implementation of `at.switch(at.isinf(x), 0, x)` (i.e. one of the sub-graphs in the log-sum-exp graph):
```python
import numba
import numpy as np


@numba.vectorize
def numba_switch_isinf(x):
    if np.isinf(x):
        return 0
    else:
        return x
```
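A note on timing functions like this (a sketch, not from the thread): the first call includes JIT compilation, so it should be excluded from the measurement, e.g.:

```python
x = np.random.normal(size=(1000, 1000))
_ = numba_switch_isinf(x)  # warm-up call triggers compilation
%timeit numba_switch_isinf(x)
```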
If we compare this with the Aesara C-compiled function, the Aesara Numba-compiled function, and the Aesara-generated Numba function called directly (i.e. without `Function`), we get the following:
data shape | Numba | Aesara-C | Aesara-Numba | Aesara-Numba (direct) |
---|---|---|---|---|
(3, 3) | 328 ns ± 1.72 ns (1000000) | 6.42 µs ± 76.1 ns (100000) | 9.97 µs ± 127 ns (100000) | 508 ns ± 2.92 ns (1000000) |
(100, 100) | 3.06 µs ± 21.1 ns (100000) | 18.7 µs ± 134 ns (100000) | 16.6 µs ± 251 ns (100000) | 3.75 µs ± 50 ns (100000) |
(1000, 1000) | 784 µs ± 39.6 µs (1000) | 1.16 ms ± 20.1 µs (1000) | 823 µs ± 12.8 µs (1000) | 789 µs ± 4.18 µs (1000) |
(5000, 5000) | 34.3 ms ± 610 µs (10) | 46.3 ms ± 1.2 ms (10) | 89.3 ms ± 3.4 ms (10) | 97 ms ± 3.12 ms (10) |
(10000, 10000) | 214 ms ± 11 ms (1) | 268 ms ± 7.68 ms (1) | 334 ms ± 1.73 ms (1) | 329 ms ± 2.14 ms (1) |
We might need to adjust our `timeit` approach, because I'm not sure how trustworthy the larger-shape statistics are, but, as previously stated, it looks like the smaller-shape results are demonstrating a `Function`-induced call overhead in the Aesara-C and Aesara-Numba results.

Regardless, someone should take a closer look at the Aesara-generated Numba function and see if there's anything that could be causing a drop in performance for larger shapes.
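For reference, here's roughly how the "Aesara-Numba (direct)" timings can be reproduced, i.e. calling the jitted function-graph and bypassing the `Function` wrapper (a sketch that reuses the `fgraph_jit` extraction shown earlier; the exact names are assumptions):

```python
import numpy as np

test_data = np.random.normal(size=(1000, 1000))

# `fgraph_jit` returns a tuple with one entry per FunctionGraph output
(res,) = fgraph_jit(test_data)
%timeit fgraph_jit(test_data)
```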
It's worth noting that we really shouldn't focus on the `fastmath` and/or `parallel` options; those are simple things to add, and they don't address the persistent underlying sources of latency (i.e. the ones that will still be present when those options are enabled).

Instead, we need to focus on how well the Aesara-generated Numba functions perform compared to Aesara's C results and equivalent Numba-only functions. In other words, we need to make Aesara's Numba-compiled functions at least as fast as the current C backend and, ideally, on par with a handwritten Numba function.
Here are the results of the log-sum-exp comparison after #498 (`parallel` and `fastmath` are off in all cases):
data shape | Numba | Aesara-C | Aesara-Numba | Aesara-Numba (direct) |
---|---|---|---|---|
(3, 3) | 635 ns ± 7.91 ns (1000000) | 11.8 µs ± 269 ns (100000) | 16.7 µs ± 459 ns (100000) | 3.53 µs ± 45.3 ns (100000) |
(25, 25) | 6.87 µs ± 55 ns (100000) | 22.1 µs ± 518 ns (10000) | 32.2 µs ± 529 ns (10000) | 17.8 µs ± 248 ns (100000) |
(100, 100) | 95.4 µs ± 1.34 µs (10000) | 160 µs ± 3.81 µs (10000) | 208 µs ± 4.03 µs (1000) | 202 µs ± 2.62 µs (10000) |
(1000, 1000) | 8.62 ms ± 51.8 µs (100) | 14.8 ms ± 363 µs (100) | 20.1 ms ± 219 µs (100) | 20.4 ms ± 700 µs (10) |
(10000, 10000) | 826 ms ± 11.4 ms (1) | 1.48 s ± 18.1 ms (1) | 4.04 s ± 43.9 ms (1) | 4.02 s ± 53.1 ms (1) |
Getting there. Where do you think the other bottlenecks are?
We need to continue scrutinizing the generated Numba functions piece by piece.
My guess is that there's still an in-place operation missing and/or an erroneous array copy taking place within the generated functions.
@kc611 and I were just going over this issue (via #529) and, while developing a clearer MWE (see here), we noticed that the `axis` argument, and the resulting transpose and loop in `careduce_axis`, is likely the main source of decreased Numba performance compared to NumPy.
@brandonwillard Does that mean we can keep using `vectorize`?
> @brandonwillard Does that mean we can keep using `vectorize`?
Possibly, but all the things that @stuartarchibald and others mentioned here still apply; however, based on our code generation, it seems like we could at least be doing better about keeping arrays/operations "contiguous-friendly".
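To illustrate what "contiguous-friendly" means here (a sketch, not from the thread): reducing along the last axis of a C-contiguous array can be done row by row without any transpose, which is the access pattern the hand-written Numba `Op` above uses.

```python
import numba
import numpy as np


@numba.njit
def max_axis1_rows(x):
    # Row-wise max over a C-contiguous 2D array: each `x[i]` is a contiguous
    # slice, so no transpose or strided access is needed.
    n = x.shape[0]
    out = np.empty(n, dtype=x.dtype)
    for i in range(n):
        out[i] = np.max(x[i])
    return out
```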
> There's also a numerically stable one:
>
> ```python
> @numba.njit(parallel=True, fastmath=True)
> def numba_logsumexp_stable(p, out):
>     n, m = p.shape
>     assert len(out) == n
>     assert out.ndim == 1
>     assert p.ndim == 2
>     for i in numba.prange(n):
>         p_max = np.max(p[i])
>         res = 0
>         for j in range(m):
>             res += np.exp(p[i, j] - p_max)
>         res = np.log(res) + p_max
>         out[i] = res
> ```
For anyone who wants to use this function directly: you have to decide whether `parallel` should be set to `True` for your use case.

As @brandonwillard's results show (this one with parallel and this one without), parallelization actually hurts performance when the array is relatively small.

I used the function above directly and found that setting `parallel=True` slowed my program down, since the input array was small, with a shape of (1, 10).
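One way to handle that trade-off (a sketch under my own assumptions, not something from the thread; the size threshold is arbitrary and would need tuning) is to compile a serial and a parallel variant and dispatch on the input size:

```python
import numba
import numpy as np


def make_logsumexp(parallel):
    # `numba.prange` degrades to a plain `range` when parallel=False
    @numba.njit(parallel=parallel, fastmath=True)
    def lse(p, out):
        n, m = p.shape
        for i in numba.prange(n):
            p_max = np.max(p[i])
            res = 0.0
            for j in range(m):
                res += np.exp(p[i, j] - p_max)
            out[i] = np.log(res) + p_max

    return lse


lse_serial = make_logsumexp(False)
lse_parallel = make_logsumexp(True)


def logsumexp_auto(p, out, size_threshold=100_000):
    # Hypothetical helper: only pay the threading overhead for large inputs
    fn = lse_parallel if p.size >= size_threshold else lse_serial
    fn(p, out)
```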
@aseyboldt provided an implementation we can probably use here as a starting point: