Open MuellerSeb opened 4 years ago
Hello,
Since this still seems to be of interest, I wanted to give some input here. I'm a co-developer of the package chemotools, where I lately integrated pentapy to speed up the solution of pentadiagonal matrices which are encountered during spline smoothing (PR). pentapy gave quite some speedup, but I also stumbled over the multiple right-hand sides problem.
Therefore, I forked pentapy to play around a little, and I found that there are still improvements that can be made in terms of runtime and especially memory. I wanted to optimize those first before going into multiple right-hand sides and even parallelization.
I've found two things:
some of the vectors are used to store data that are never used/don't have to be used:
cdef double[:] result = np.zeros(mat_j) # <- `ze` is a freshly allocated vector, it can hold the `result`
cdef double[:] al = np.zeros(mat_j)
cdef double[:] be = np.zeros(mat_j)
cdef double[:] ze = np.zeros(mat_j)
cdef double[:] ga = np.zeros(mat_j) # <- never used for backwards read
cdef double[:] mu = np.zeros(mat_j) # <- never used for backwards read
ze is a freshly allocated vector for holding the transformed right-hand side. It's common practice to overwrite this vector during the backward/forward substitutions. It should be noted that using ze to hold result saves ~1/6 of the additional memory required by the algorithm (not considering the left-hand side matrix and the right-hand side vector themselves).
ga and mu are used to store elements that are never truly accessed again:
ga[1] = mat_flat[3, 1]  # <- `ga[1]` is never truly referenced again
mu[1] = mat_flat[2, 1] - al[0] * ga[1]  # <- `mu[1]` is never truly referenced again
al[1] = (mat_flat[1, 1] - be[0] * ga[1]) / mu[1]
be[1] = mat_flat[0, 1] / mu[1]
ze[1] = (rhs[1] - ze[0] * ga[1]) / mu[1]

for i in range(2, mat_j-2):
    ga[i] = mat_flat[3, i] - al[i-2] * mat_flat[4, i]  # <- `ga[i]` is never truly referenced again
    mu[i] = mat_flat[2, i] - be[i-2] * mat_flat[4, i] - al[i-1] * ga[i]  # <- `mu[i]` is never truly referenced again
    al[i] = (mat_flat[1, i] - be[i-1] * ga[i]) / mu[i]
    be[i] = mat_flat[0, i] / mu[i]
    ze[i] = (rhs[i] - ze[i-2] * mat_flat[4, i] - ze[i-1] * ga[i]) / mu[i]
One might now say that they are referenced right after, but that's why I added the "truly" in there (I will come to that in a moment). Again, not initialising them leads to savings of 2/6 of the total memory. Together with result, we can get rid of ~50% of the memory requirements 🤯 This means that running the algorithm two times in parallel has roughly the same memory requirements as the current version has for a single run. On top of that, the initialisation has a Python interaction, and removing it gives a speedup for free.
all variables are written to and read again from the internally allocated vectors. While Cython has optimized indexing compared to Python, a memory access to an array in RAM is hard to make fast. This is where caching comes in, and the pentadiagonal solvers have a bit of a random access order by design (that's naturally given by the original publication and nothing one can fix here unless there is a better algorithm).
In the pentapy implementation, a lot of intermediate results are computed, stored in vectors, and read again right after. Since array accesses are quite expensive, it comes as no surprise that the algorithm spends a significant fraction of its time in read/write processes while it should mostly be doing numerics.
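To put the memory point in numbers, here is a tiny back-of-the-envelope estimate in plain Python (the number of unknowns and the float64 size are just chosen for illustration):

import numpy as np

# auxiliary memory of the solver (ignoring the banded matrix and the
# right-hand side themselves) for a system with n unknowns in float64
n = 1_000_000
float64_bytes = np.dtype(np.float64).itemsize  # 8 bytes

aux_current = 6 * n * float64_bytes    # result, al, be, ze, ga, mu
aux_proposed = 3 * n * float64_bytes   # al, be, ze (ze also holds the result)

print(f"current:  {aux_current / 1e6:.0f} MB")
print(f"proposed: {aux_proposed / 1e6:.0f} MB")
print(f"saved:    {1 - aux_proposed / aux_current:.0%}")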
How can this be fixed?
First of all, result can be replaced by ze in a very straightforward way.
Then ga and mu can be dropped completely in favour of the intermediate scalar variables ga_i and mu_i, so basically 2 vectors of n 64-bit floats are replaced by 2 single 64-bit floats that can be accessed directly.
Now comes the heavy part of using intermediate results efficiently. All values stored in al, be, and ze are frequently reused. Usually, x[i-2] and x[i-1] are involved in computing all the variables for the i-th step, where x is a placeholder for the respective vectors. Therefore, the variables

al_i, al_i_minus_1, al_i_plus_1
be_i, be_i_minus_1, be_i_plus_1
ze_i, ze_i_minus_1, ze_i_plus_1

are introduced (so in total 9 more 64-bit floats). With their help, the vector accesses can be reduced to 1 write in the factorization and 1 read in the backward substitution.
In that go, we can also read mat_flat[4, i] once per row and use it over and over again (it was read from many times before).
Doing so, I ended up with the following code for Solver 1 (I didn't yet go over Solver 2):
# cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True
import numpy as np

cimport numpy as np
from libc.stdint cimport int64_t


def penta_solver_new(double[::, ::1] mat_flat, double[:] rhs):
    return np.asarray(c_penta_solver_new(mat_flat, rhs))


cdef double[:] c_penta_solver_new(double[::, ::1] mat_flat, double[:] rhs):

    # Variable declarations
    cdef int64_t mat_j = mat_flat.shape[1]

    cdef double[:] al = np.zeros(mat_j)
    cdef double[:] be = np.zeros(mat_j)
    cdef double[:] ze = np.zeros(mat_j)

    cdef int64_t i
    cdef double mu_i, ga_i, e_i
    cdef double al_i, al_i_minus_1, al_i_plus_1
    cdef double be_i, be_i_minus_1, be_i_plus_1
    cdef double ze_i, ze_i_minus_1, ze_i_plus_1

    # Factorization
    # First row
    mu_i = mat_flat[2, 0]
    al_i_minus_1 = mat_flat[1, 0] / mu_i
    al[0] = al_i_minus_1
    be_i_minus_1 = mat_flat[0, 0] / mu_i
    be[0] = be_i_minus_1
    ze_i_minus_1 = rhs[0] / mu_i
    ze[0] = ze_i_minus_1

    # Second row
    ga_i = mat_flat[3, 1]
    mu_i = mat_flat[2, 1] - al_i_minus_1 * ga_i
    al_i = (mat_flat[1, 1] - be_i_minus_1 * ga_i) / mu_i
    al[1] = al_i
    be_i = mat_flat[0, 1] / mu_i
    be[1] = be_i
    ze_i = (rhs[1] - ze_i_minus_1 * ga_i) / mu_i
    ze[1] = ze_i

    # Central rows
    for i in range(2, mat_j-2):
        e_i = mat_flat[4, i]
        ga_i = mat_flat[3, i] - al_i_minus_1 * e_i
        mu_i = mat_flat[2, i] - be_i_minus_1 * e_i - al_i * ga_i

        al_i_plus_1 = (mat_flat[1, i] - be_i * ga_i) / mu_i
        al_i_minus_1 = al_i
        al_i = al_i_plus_1
        al[i] = al_i_plus_1

        be_i_plus_1 = mat_flat[0, i] / mu_i
        be_i_minus_1 = be_i
        be_i = be_i_plus_1
        be[i] = be_i_plus_1

        ze_i_plus_1 = (rhs[i] - ze_i_minus_1 * e_i - ze_i * ga_i) / mu_i
        ze_i_minus_1 = ze_i
        ze_i = ze_i_plus_1
        ze[i] = ze_i_plus_1

    # Second to last row
    e_i = mat_flat[4, mat_j-2]
    ga_i = mat_flat[3, mat_j-2] - al_i_minus_1 * e_i
    mu_i = mat_flat[2, mat_j-2] - be_i_minus_1 * e_i - al_i * ga_i
    al_i_plus_1 = (mat_flat[1, mat_j-2] - be_i * ga_i) / mu_i
    al[mat_j-2] = al_i_plus_1
    ze_i_plus_1 = (rhs[mat_j-2] - ze_i_minus_1 * e_i - ze_i * ga_i) / mu_i
    ze_i_minus_1 = ze_i
    ze_i = ze_i_plus_1
    ze[mat_j-2] = ze_i_plus_1

    # Last row
    e_i = mat_flat[4, mat_j-1]
    ga_i = mat_flat[3, mat_j-1] - al_i * e_i
    mu_i = mat_flat[2, mat_j-1] - be_i * e_i - al_i_plus_1 * ga_i
    ze_i_plus_1 = (rhs[mat_j-1] - ze_i_minus_1 * e_i - ze_i * ga_i) / mu_i
    ze[mat_j-1] = ze_i_plus_1

    # Backward substitution that overwrites ze with the result
    ze_i -= al_i_plus_1 * ze_i_plus_1
    ze[mat_j-2] = ze_i

    for i in range(mat_j-3, -1, -1):
        ze[i] -= al[i] * ze_i + be[i] * ze_i_plus_1
        ze_i_plus_1 = ze_i
        ze_i = ze[i]

    return ze
Is all of this not only micro-optimization?
My first attempts ended up making everything slower and I almost dropped the idea. But after applying the mentioned optimizations, I found a consistent 20...25% speedup with perfplot (applied to the solver only, nothing around it).
This is admittedly not a quantum leap when it comes to solving pentadiagonal systems, but it's also not nothing. These optimizations can easily save a few minutes in the applications encountered in chemotools. It also means that 5 right-hand sides can be solved in the time that 4 take with the current version. Besides, this only considers the runtime, not the memory savings.
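For reference, a perfplot setup along these lines (the random test system and the kernel names are placeholders here, not the exact script I used; penta_solver_current stands in for the existing pentapy routine, penta_solver_new is the code above) looks roughly like:

import numpy as np
import perfplot

def make_system(n):
    # random, diagonally dominant test system in the flattened (5, n) band format
    rng = np.random.default_rng(0)
    mat_flat = rng.standard_normal((5, n))
    mat_flat[2, :] += 10.0  # strengthen the main diagonal
    rhs = rng.standard_normal(n)
    return mat_flat, rhs

perfplot.show(
    setup=make_system,
    kernels=[
        lambda data: penta_solver_current(*data),  # placeholder for the current solver
        lambda data: penta_solver_new(*data),      # optimized variant from above
    ],
    labels=["current", "optimized"],
    n_range=[2 ** k for k in range(4, 18)],
    equality_check=np.allclose,
)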
How to proceed
Would you be interested in a collaboration and in me submitting a PR on this from the fork I made?
Thanks for your time!
I wish you a great start into the week!
Kind regards
Appendix - What did not do any good
Changing from C-contiguous to F-contiguous arrays, which would better fit the solver's access pattern, i.e.,

def penta_solver_new(double[::1, :] mat_flat, double[:] rhs):  # <- explicit F-order array here
    return np.asarray(c_penta_solver_new(mat_flat, rhs))

cdef double[:] c_penta_solver_new(double[::1, :] mat_flat, double[:] rhs):  # <- explicit F-order array here

    # Variable declarations
    cdef int64_t mat_j = mat_flat.shape[1]

    ...

did not really improve anything; the effect was only minor.
Hey there!
This is awesome 🎉
I have to admit that I just implemented the algorithms straight from the referenced paper and didn't care too much about memory efficiency.
I am happy to review a PR from you! Could you update the second algorithm as well?
Cheers, Sebastian
Of course I can also update the second algorithm.
Do you have a timeline on this? Just for me to plan.
Would you actually be interested in porting the low-level implementation to Rust? That would make it more accessible, and the Python interface is also relatively straightforward. With your CI to build wheels already in place, distribution would be manageable. I think this repository is great, but the use of Cython unnecessarily limits its use to Python only 🤔
We did something similar with https://github.com/GeoStat-Framework/GSTools-Core/ in the past.
Since I recently created a solver.pxd file to be able to cimport the solver module in other Cython code, I wouldn't drop the Cython implementation.
There is already a mechanism to select the desired solving backend (like solve_banded, scipy sparse, umf_pack, ...), so it could be an idea to create a Rust implementation with Python bindings and add it to the list of available solvers.
What's your opinion here?
Of course I can also update the second algorithm.
Do you have a timeline on this? Just for me to plan.
There is no timeline from my side with this package. ATM just doing simple maintenance.
There is already a mechanism to select the desired solving backend (like solve_banded, scipy sparse, umf_pack ...), so it could be an idea to create a rust implementation with python bindings and add it to the list of available solvers.
Ah yeah, that really makes sense!
There is no timeline from my side with this package. ATM just doing simple maintenance.
I'll see what I can do this week with respect to improving algorithm II and also the multiple right-hand sides.
I wanted to provide a tiny update on what I did for Algorithm 1 to make it capable of handling multiple right-hand sides. After I looked at the algorithm for long enough, I figured that the factorization of the left-hand side matrix is independent of the transformation and solve of the right-hand side. This allows for splitting them up and saving the factorization results so they can be applied to multiple right-hand sides afterwards. An implementation can look like:
%%cython -a
# cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True
import numpy as np

cimport numpy as np
from libc.stdint cimport int64_t


def penta_solver_new(double[:, :] mat_flat, double[:] rhs):
    return np.asarray(c_penta_solver_new(mat_flat, rhs))


cdef double[:] c_penta_solver_new(double[:, :] mat_flat, double[:] rhs):
    """
    Solves the pentadiagonal system of equations ``Ax = b`` with the matrix ``A`` and
    the right-hand side ``b`` by

    - factorizing the matrix ``A`` into auxiliary coefficients and a unit upper
      triangular matrix
    - transforming the right-hand side into a vector ``zeta``
    - solving the system of equations by backward substitution

    """
    cdef int64_t mat_n_rows = mat_flat.shape[1]

    cdef double[:] result = np.empty(shape=(mat_n_rows,))
    cdef double[::, ::1] mat_factorized = np.empty(shape=(mat_n_rows - 1, 2))
    cdef double[::, ::1] aux_coeffs = np.empty(shape=(mat_n_rows, 3))

    c_penta_factorize_algo1(mat_flat, mat_n_rows, mat_factorized, aux_coeffs)
    solve_penta_from_factorize_algo_1(mat_n_rows, mat_factorized, aux_coeffs, rhs, result)

    return result


def penta_factorize_algo1(double[:, :] mat_flat):
    """
    Test function for a factorization only. Can be ignored.

    """
    cdef int64_t mat_n_rows = mat_flat.shape[1]

    cdef double[::, ::1] mat_factorized = np.empty(shape=(mat_n_rows - 1, 2))
    cdef double[::, ::1] aux_coeffs = np.empty(shape=(mat_n_rows, 3))

    c_penta_factorize_algo1(mat_flat, mat_n_rows, mat_factorized, aux_coeffs)
cdef void c_penta_factorize_algo1(
    double[:, :] mat_flat,
    int64_t mat_n_rows,
    double[::, ::1] mat_factorized,
    double[::, ::1] aux_coeffs,
):
    """
    Factorizes the pentadiagonal matrix ``A`` into

    - auxiliary coefficients ``e``, ``mu`` and ``gamma`` for the transformation of
      the right-hand side
    - a unit upper triangular matrix with the superdiagonals ``alpha`` and ``beta``
      for the following backward substitution (its unit main diagonal is implicit)

    All coefficients are stored in the columns of the respective arrays.

    """

    ### Variable declarations ###
    cdef int64_t iter_row
    cdef double mu_i, ga_i, e_i
    cdef double al_i, al_i_minus_1, al_i_plus_1
    cdef double be_i, be_i_minus_1, be_i_plus_1

    ### Factorization ###
    # First row
    mu_i = mat_flat[2, 0]
    al_i_minus_1 = mat_flat[1, 0] / mu_i
    be_i_minus_1 = mat_flat[0, 0] / mu_i

    aux_coeffs[0, 1] = mu_i
    mat_factorized[0, 0] = al_i_minus_1
    mat_factorized[0, 1] = be_i_minus_1

    # Second row
    ga_i = mat_flat[3, 1]
    mu_i = mat_flat[2, 1] - al_i_minus_1 * ga_i
    al_i = (mat_flat[1, 1] - be_i_minus_1 * ga_i) / mu_i
    be_i = mat_flat[0, 1] / mu_i

    aux_coeffs[1, 1] = mu_i
    aux_coeffs[1, 2] = ga_i
    mat_factorized[1, 0] = al_i
    mat_factorized[1, 1] = be_i

    # Central rows
    for iter_row in range(2, mat_n_rows-2):
        e_i = mat_flat[4, iter_row]
        ga_i = mat_flat[3, iter_row] - al_i_minus_1 * e_i
        mu_i = mat_flat[2, iter_row] - be_i_minus_1 * e_i - al_i * ga_i

        al_i_plus_1 = (mat_flat[1, iter_row] - be_i * ga_i) / mu_i
        al_i_minus_1 = al_i
        al_i = al_i_plus_1

        be_i_plus_1 = mat_flat[0, iter_row] / mu_i
        be_i_minus_1 = be_i
        be_i = be_i_plus_1

        aux_coeffs[iter_row, 0] = e_i
        aux_coeffs[iter_row, 1] = mu_i
        aux_coeffs[iter_row, 2] = ga_i
        mat_factorized[iter_row, 0] = al_i
        mat_factorized[iter_row, 1] = be_i

    # Second to last row
    e_i = mat_flat[4, mat_n_rows-2]
    ga_i = mat_flat[3, mat_n_rows-2] - al_i_minus_1 * e_i
    mu_i = mat_flat[2, mat_n_rows-2] - be_i_minus_1 * e_i - al_i * ga_i
    al_i_plus_1 = (mat_flat[1, mat_n_rows-2] - be_i * ga_i) / mu_i

    aux_coeffs[mat_n_rows-2, 0] = e_i
    aux_coeffs[mat_n_rows-2, 1] = mu_i
    aux_coeffs[mat_n_rows-2, 2] = ga_i
    mat_factorized[mat_n_rows-2, 0] = al_i_plus_1
    mat_factorized[mat_n_rows-2, 1] = 0.0

    # Last row
    e_i = mat_flat[4, mat_n_rows-1]
    ga_i = mat_flat[3, mat_n_rows-1] - al_i * e_i
    mu_i = mat_flat[2, mat_n_rows-1] - be_i * e_i - al_i_plus_1 * ga_i

    aux_coeffs[mat_n_rows-1, 0] = e_i
    aux_coeffs[mat_n_rows-1, 1] = mu_i
    aux_coeffs[mat_n_rows-1, 2] = ga_i

    return
cdef void solve_penta_from_factorize_algo_1(
    int64_t mat_n_rows,
    double[::, ::1] mat_factorized,
    double[::, ::1] aux_coeffs,
    double[::] rhs_single,
    double[::] result_view,
):
    """
    Solves the pentadiagonal system of equations ``Ax = b`` with the factorized
    matrix ``A`` and the right-hand side ``b``.
    It transforms the right-hand side first into the vector ``zeta`` and overwrites
    this vector with the solution vector ``x``.

    """

    ### Variable declarations ###
    cdef int64_t iter_row
    cdef double ze_i, ze_i_minus_1, ze_i_plus_1

    ### Transformation ###
    # first, the right-hand side is transformed into the vector ``zeta``
    # First row
    ze_i_minus_1 = rhs_single[0] / aux_coeffs[0, 1]
    result_view[0] = ze_i_minus_1

    # Second row
    ze_i = (rhs_single[1] - ze_i_minus_1 * aux_coeffs[1, 2]) / aux_coeffs[1, 1]
    result_view[1] = ze_i

    # Central rows
    for iter_row in range(2, mat_n_rows-2):
        ze_i_plus_1 = (rhs_single[iter_row] - ze_i_minus_1 * aux_coeffs[iter_row, 0] - ze_i * aux_coeffs[iter_row, 2]) / aux_coeffs[iter_row, 1]
        ze_i_minus_1 = ze_i
        ze_i = ze_i_plus_1
        result_view[iter_row] = ze_i_plus_1

    # Second to last row
    ze_i_plus_1 = (rhs_single[mat_n_rows-2] - ze_i_minus_1 * aux_coeffs[mat_n_rows-2, 0] - ze_i * aux_coeffs[mat_n_rows-2, 2]) / aux_coeffs[mat_n_rows-2, 1]
    ze_i_minus_1 = ze_i
    ze_i = ze_i_plus_1
    result_view[mat_n_rows-2] = ze_i_plus_1

    # Last row
    ze_i_plus_1 = (rhs_single[mat_n_rows-1] - ze_i_minus_1 * aux_coeffs[mat_n_rows-1, 0] - ze_i * aux_coeffs[mat_n_rows-1, 2]) / aux_coeffs[mat_n_rows-1, 1]
    result_view[mat_n_rows-1] = ze_i_plus_1

    ### Backward substitution ###
    # the solution vector is obtained by a backward substitution that overwrites
    # the transformed right-hand side ``zeta`` stored in ``result_view``
    ze_i -= mat_factorized[mat_n_rows-2, 0] * ze_i_plus_1
    result_view[mat_n_rows-2] = ze_i

    for iter_row in range(mat_n_rows-3, -1, -1):
        result_view[iter_row] -= mat_factorized[iter_row, 0] * ze_i + mat_factorized[iter_row, 1] * ze_i_plus_1
        ze_i_plus_1 = ze_i
        ze_i = result_view[iter_row]

    return
What did that do to the runtime?
The algorithm is now as fast as the current implementation in pentapy.
What does it matter? The straightforward approach of just running a full solve (factorization + transformation + triangular solve) for $n$ right-hand sides in a (potentially parallelized) loop would take

$$n\cdot\left(t_{f}+t_{tt}\right),$$

where $t_{f}$ is the time for one factorization and $t_{tt}$ the time for one transformation plus triangular solve. On the other hand, the split approach that saves the factorization and then simply applies it for the transformations and triangular solves would require

$$t_{f}+n\cdot t_{tt}.$$

I timed the factorization and the transformations plus triangular solves individually and found that both steps require roughly the same amount of time, i.e., $t_{f}\approx t_{tt}$. That means the straightforward approach takes $n\cdot\left(t_{f}+t_{tt}\right)=2\cdot n\cdot t_{tt}$, while the split approach only takes approximately $1\cdot n\cdot t_{tt}$ as $n$ approaches $\infty$, and it already outperforms the former for the tiniest values of $n$.
In other words, the split approach will, to a good approximation, only require the time of a standard unit tridiagonal transformation and backward substitution per right-hand side, while the factorization becomes negligible.
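As a quick back-of-the-envelope check of these two estimates, taking the measured $t_{f}\approx t_{tt}$ and setting both to one arbitrary time unit:

# compare the two strategies for an increasing number of right-hand sides n
t_f, t_tt = 1.0, 1.0

for n in (1, 2, 10, 100, 1000):
    t_straightforward = n * (t_f + t_tt)  # full solve for every right-hand side
    t_split = t_f + n * t_tt              # factorize once, reuse for every rhs
    print(f"n={n:>4}: straightforward={t_straightforward:7.1f}, "
          f"split={t_split:7.1f}, ratio={t_straightforward / t_split:.2f}")

The ratio tends to 2 for large $n$, in line with the estimate above.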
Extending the above code to multiple right-hand sides is quite simple by iterating over the columns of rhs after the factorization is done. For now, I'll probably just do a plain loop, and then we can see whether parallelization is an issue or not, given that the solves are very, very fast and even 1000 right-hand sides with 10000 variables should only take ~100 ms when I extrapolate what we see here. Something that fast is hardly worth parallelizing given the parallelization overhead.
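A rough sketch of how that column loop could look, meant to live in the same module as the functions above (the multi-rhs wrapper itself is made up and untested; it assumes the transform/solve kernel accepts strided column views):

cdef double[:, :] c_penta_solver_multi(double[:, :] mat_flat, double[:, :] rhs):
    """
    Sketch: solve ``A @ X = B`` for a 2D right-hand side with one system per
    column by factorizing once and reusing the factorization for every column.

    """
    cdef int64_t mat_n_rows = mat_flat.shape[1]
    cdef int64_t n_rhs = rhs.shape[1]
    cdef int64_t iter_col

    cdef double[:, :] result = np.empty(shape=(mat_n_rows, n_rhs))
    cdef double[::, ::1] mat_factorized = np.empty(shape=(mat_n_rows - 1, 2))
    cdef double[::, ::1] aux_coeffs = np.empty(shape=(mat_n_rows, 3))

    # the factorization of the left-hand side happens exactly once ...
    c_penta_factorize_algo1(mat_flat, mat_n_rows, mat_factorized, aux_coeffs)

    # ... and is reused for every right-hand side; this plain loop would be the
    # candidate for a later ``prange``
    for iter_col in range(n_rhs):
        solve_penta_from_factorize_algo_1(
            mat_n_rows, mat_factorized, aux_coeffs,
            rhs[:, iter_col], result[:, iter_col],
        )

    return result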
@MuellerSeb I would highly appreciate your input on the parallelization (and all the other points, of course 😉). What are your/the users' typical use cases? For me, the worst case is solving roughly 3000 right-hand sides with 4000 variables each, so this will easily be done in less than a second. Of course, I can still try what prange brings in terms of performance.
If we go in that direction, would you like to have an argument n_workers in the Python interface, as is typical for parallelized SciPy functions? Then the user would have full control over how many cores are used.
@MuellerSeb
I hope you are doing well!
I wanted to ask what your timeline is on this issue and whether you need me to provide further input. I don't want to build up any pressure, please take your time! I just need to make a design decision that could involve a multithreaded pentapy or another solution, and an estimate would be very helpful for making that decision. Thanks!
Hey there, sorry for the long delay. I am currently preparing my parental leave and there are a lot of issues on my table to solve before leaving.
Since this package only has 14 stars ATM, there are not that many users and you seem to be the first power user besides me. I mainly wrote this package for Anaflow, where I solve a PDE in Laplace space, and since the matrix changes for every time step, I only solve for one RHS at a time.
I would really love to see the usage of prange since I have made good experiences with it. It mostly only means exchanging range with prange for the outermost or innermost loop. The nice thing about prange is that it acts like a normal range when we are not linking against OpenMP.
OpenMP support should be disabled by default, since cibuildwheel would include the OpenMP runtime in the wheel, which is bad practice since it could mess with other packages using OpenMP (on conda this is fine).
To do so, I would use an env var that controls this. Have a look at the GSTools setup, where we implemented this; there we also used the https://github.com/astropy/extension-helpers/ package from astropy to find the correct compile flags to link against OpenMP.
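Just to sketch what that wiring could look like in a setup.py (the env var name PENTAPY_OPENMP and the source path are made up here; add_openmp_flags_if_available is the extension-helpers function mentioned above):

import os

import numpy as np
from Cython.Build import cythonize
from extension_helpers import add_openmp_flags_if_available
from setuptools import Extension, setup

ext = Extension(
    "pentapy.solver",
    sources=["src/pentapy/solver.pyx"],  # placeholder path
    include_dirs=[np.get_include()],
)

# only link against OpenMP when explicitly requested via an env var, so the
# PyPI wheels built by cibuildwheel do not ship their own OpenMP runtime
if os.environ.get("PENTAPY_OPENMP", "0") == "1":
    add_openmp_flags_if_available(ext)

setup(ext_modules=cythonize([ext], language_level=3))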
prange comes with an argument num_threads which is None by default to use all available cores. This argument could be forwarded from the solve routine with None as default. The high-level name is debatable.
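Applied to the right-hand-side loop from your sketch above, the swap would look roughly like this (untested; the transform/solve kernel would additionally need a nogil qualifier, and num_threads is taken as a plain int here):

from cython.parallel import prange

cdef void c_solve_all_rhs(
    int64_t mat_n_rows,
    int64_t n_rhs,
    double[::, ::1] mat_factorized,
    double[::, ::1] aux_coeffs,
    double[:, :] rhs,
    double[:, :] result,
    int num_threads,
):
    cdef int64_t iter_col
    # with OpenMP compile flags this runs the columns in parallel; without
    # them, prange falls back to a normal sequential range
    for iter_col in prange(n_rhs, nogil=True, num_threads=num_threads):
        solve_penta_from_factorize_algo_1(
            mat_n_rows, mat_factorized, aux_coeffs,
            rhs[:, iter_col], result[:, iter_col],
        )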
Thanks for all the great work. The performance improvements look really good.
Cheers, Sebastian
Thanks for checking back! I already wish all the best for you and your family then 🎉🥳
Thanks for pointing out the OpenMP issues. I wasn't aware of that since I never packaged a Python package 🙏 I'll come up with some updates then and already incorporate the feedback on the PR.
I think you really saved my a** with the OpenMP issue that would have kicked me quite hard for another package I'm writing.
I incorporated all the changes where you provided input so far 👍 Have a great time! 🙃
At the moment, only one dependent variable b is allowed to solve A @ x = b for x. In other algorithms, multiple dependent variables can be passed, e.g.: https://numpy.org/devdocs/reference/generated/numpy.linalg.solve.html
Since the solver is implemented in Cython, this could be optimized by using prange.
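To illustrate the request, a small usage example (the column-by-column calls below work with today's pentapy.solve; the single call for a whole block of right-hand sides is the hypothetical part):

import numpy as np
import pentapy as pp

n, k = 1000, 50
rng = np.random.default_rng(0)

# diagonally dominant pentadiagonal test matrix in full (quadratic) form
mat = (
    np.diag(rng.random(n) + 10.0)
    + np.diag(rng.random(n - 1), 1) + np.diag(rng.random(n - 1), -1)
    + np.diag(rng.random(n - 2), 2) + np.diag(rng.random(n - 2), -2)
)
B = rng.random((n, k))

# desired interface (hypothetical): one call handling all k right-hand sides
# X = pp.solve(mat, B)

# what works today: solve column by column
X = np.column_stack([pp.solve(mat, B[:, j]) for j in range(k)])
assert np.allclose(mat @ X, B)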