Becksteinlab / kda

Python package used for the analysis of biochemical kinetic diagrams.
GNU General Public License v3.0

ENH: Improve partial diagrams algo #116

Status: Closed (nawtrey closed this 3 months ago)

nawtrey commented 3 months ago

Changes

Description

While there is not much performance to gain from changes to `generate_partial_diagrams`, I spent some time trying other algorithms to see how they compare to KDA.

Wang Algebra Algorithm

The KAPattern software uses Wang algebra to efficiently generate the state probability expressions, so I gave that a shot first.
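For readers unfamiliar with Wang algebra: edge labels commute and obey two extra rules, x + x = 0 and x · x = 0. If you take every node except one, write down the sum of the labels of its incident edges, multiply those N-1 sums together and simplify with the two rules, exactly one term survives per spanning tree. The sketch below illustrates those rules without SymPy, under stated assumptions: monomials are frozensets of edges, coefficients are kept mod 2 by toggling set membership, and the function name and default `root` are hypothetical. It is neither KAPattern's nor KDA's implementation, and no performance claim is made for it.

```python
def wang_spanning_trees(n_nodes, edges, root=0):
    """Expand prod over non-root nodes of (sum of incident edge labels) in Wang algebra.

    Hypothetical helper for illustration. Nodes are 0..n_nodes-1 and `edges`
    is a list of undirected (u, v) tuples, each edge listed once. A monomial
    is a frozenset of edges; the polynomial is a set of monomials whose
    coefficients are taken mod 2.
    """
    poly = {frozenset()}  # the constant polynomial 1
    for node in range(n_nodes):
        if node == root:
            continue
        incident = [edge for edge in edges if node in edge]
        new_poly = set()
        for mono in poly:
            for edge in incident:
                if edge in mono:
                    continue  # x * x = 0: a repeated edge kills the term
                # x + x = 0: a monomial produced twice cancels, so toggle membership
                new_poly ^= {mono | {edge}}
        poly = new_poly
    # every surviving monomial is the edge set of one spanning tree
    return poly


if __name__ == "__main__":
    triangle = [(0, 1), (1, 2), (0, 2)]
    for tree in wang_spanning_trees(3, triangle):
        print(sorted(tree))
```

For the triangle, (k01 + k02)(k01 + k12) expands to k01·k12 + k01·k02 + k02·k12 once the k01·k01 term is dropped, i.e. the three spanning trees. Cycles cancel because every edge set containing a cycle appears an even number of times in the expansion.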

Here is the function I wrote, which uses SymPy to generate the algebraic expressions:

```python
import networkx as nx
import numpy as np
from sympy import Add, Symbol

from kda.diagrams import enumerate_partial_diagrams


def _gen_spanning_trees_wang(G, return_edges=False):
    # make an undirected `nx.Graph` version of the input `nx.MultiDiGraph`
    G_undirected = nx.Graph()
    G_undirected.add_edges_from(G.to_undirected(reciprocal=False, as_view=True).edges())
    # assign unique edge labels
    wang_subs = {}
    for u, v in G_undirected.edges():
        # xy = yx for any x and y
        edge_label = Symbol(f"k{u}{v}", commutative=True)
        G_undirected[u][v]["label"] = edge_label
        # x + x = 0 for any x
        # x * x = 0 for any x
        wang_subs[edge_label + edge_label] = 0
        wang_subs[edge_label * edge_label] = 0
    # get N-1 nodes, dropping the highest degree node in `G`
    targets = [n for (n, degree) in sorted(G_undirected.degree(), key=lambda x: x[1])][:-1]
    # multiply out the N-1 node expressions, applying the Wang algebra
    # rules via variable substitution after each multiplication
    # (~33% faster than a single expand/subs at the end)
    expr_wang = 1
    for i, t in enumerate(targets):
        expr = sum(G_undirected[t][n]["label"] for n in G_undirected.neighbors(t))
        expr_wang *= expr
        if i > 0:
            expr_wang = expr_wang.expand().subs(wang_subs)
    del G_undirected
    # calculate the number of expected partial diagrams
    n_partials = enumerate_partial_diagrams(G)
    # preallocate arrays for storing partial diagram edges/graphs
    base_nodes = G.nodes()
    if return_edges:
        partial_diagrams = np.empty((n_partials, G.number_of_nodes() - 1, 2), dtype=np.int32)
    else:
        partial_diagrams = np.empty(n_partials, dtype=object)
    # iterate over the terms in the expression (each term represents a unique
    # spanning tree from `G`); `Add.make_args` also covers a single-term result,
    # where `expr_wang.args` would yield the individual rates instead
    for i, term in enumerate(Add.make_args(expr_wang)):
        # split the terms into individual rates (e.g. ['k01', 'k03', 'k12'])
        term_list = str(term).split("*")
        # build the edge tuples from the rates (e.g. [(0, 1), (0, 3), (1, 2)]);
        # this parsing assumes single-digit node indices
        partial_edges = [(int(u), int(v)) for (_k, u, v) in term_list]
        if return_edges:
            partial_diagrams[i, :] = partial_edges
        else:
            G_partial = nx.Graph()
            G_partial.add_nodes_from(base_nodes)
            G_partial.add_edges_from(partial_edges)
            partial_diagrams[i] = G_partial
    return partial_diagrams
```

The code works and all KDA tests pass. However, it is extremely slow: for the EmrE 8-state model it takes roughly 4.6 s to generate the spanning trees, whereas the KDA algorithm takes <20 ms. The bottleneck is SymPy's `.expand()` and `.subs()` methods, which have to multiply out ("FOIL") very large multivariate polynomials and then perform the variable substitutions. Expanding and substituting while the expression is being built (as in the loop above) made it ~33% faster, but it was still nowhere close to KDA. SymPy is written in pure Python and is not known for high performance on tasks like this.

NetworkX Algorithm

Disappointed by the Wang algebra code, I figured I might as well try NetworkX's `SpanningTreeIterator`.

Here is the code:

```python
import networkx as nx
import numpy as np
from networkx.algorithms.tree.mst import SpanningTreeIterator

from kda.diagrams import enumerate_partial_diagrams


def _gen_spanning_trees_networkx(G, return_edges=False, key="val"):
    # make an undirected `nx.Graph` version of the input `nx.MultiDiGraph`
    G_undirected = nx.Graph()
    G_undirected.add_edges_from(G.to_undirected(reciprocal=False, as_view=True).edges())
    # preallocate storage for the expected number of partial diagrams
    n_partials = enumerate_partial_diagrams(G)
    partial_diagrams = np.empty(n_partials, dtype=object)
    # hand the spanning tree generation off to NetworkX entirely
    for i, partial_diagram in enumerate(
        SpanningTreeIterator(G=G_undirected, weight=key, minimum=True)
    ):
        partial_diagrams[i] = partial_diagram
    if return_edges:
        # convert each spanning tree into an (N-1, 2) array of edge tuples
        _partial_diagrams = np.empty((n_partials, G.number_of_nodes() - 1, 2), dtype=np.int32)
        for i, _partial_diagram in enumerate(partial_diagrams):
            _partial_diagrams[i] = list(_partial_diagram.edges())
        partial_diagrams = _partial_diagrams
    return partial_diagrams
```

Again, this passes all KDA tests. The code is straightforward since the spanning tree generation is handed off to NetworkX entirely. However, it still does not perform well compared to KDA: for the EmrE 8-state model I believe the spanning trees were generated in roughly 1 s, which is not terrible, but still well over an order of magnitude slower than the current KDA implementation (<20 ms).
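Both functions above preallocate their output using `enumerate_partial_diagrams(G)`, i.e. the number of spanning trees of the undirected graph. One standard way to get that number without enumerating anything is Kirchhoff's matrix-tree theorem: any cofactor of the graph Laplacian L = D - A equals the spanning tree count. The sketch below is an illustration under stated assumptions (hypothetical `count_spanning_trees`, a simple undirected edge list on nodes 0..N-1, exact `Fraction` arithmetic to avoid floating-point rounding); it is not necessarily how `kda` implements `enumerate_partial_diagrams`.

```python
from fractions import Fraction


def count_spanning_trees(n_nodes, edges):
    """Count spanning trees via Kirchhoff's matrix-tree theorem (hypothetical helper).

    `edges` is a list of undirected (u, v) tuples on nodes 0..n_nodes-1,
    with each edge listed once (collapse reciprocal/multi-edges first).
    """
    # build the graph Laplacian L = D - A with exact rational entries
    lap = [[Fraction(0)] * n_nodes for _ in range(n_nodes)]
    for u, v in edges:
        lap[u][u] += 1
        lap[v][v] += 1
        lap[u][v] -= 1
        lap[v][u] -= 1
    # any cofactor of L gives the count; drop the last row and column
    m = [row[:-1] for row in lap[:-1]]
    size = n_nodes - 1
    det = Fraction(1)
    # determinant by Gaussian elimination
    for col in range(size):
        pivot = next((r for r in range(col, size) if m[r][col] != 0), None)
        if pivot is None:
            return 0  # singular minor: the graph is disconnected
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            det = -det
        det *= m[col][col]
        for r in range(col + 1, size):
            factor = m[r][col] / m[col][col]
            for c in range(col, size):
                m[r][c] -= factor * m[col][c]
    return int(det)


if __name__ == "__main__":
    k4 = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
    print(count_spanning_trees(4, k4))  # 16, matching Cayley's formula 4**(4 - 2)
```

Because this count is cheap, it can size the output array up front, which is why both functions above call it before generating any trees.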

KDA Updated Algorithm

This brings us to the changes in this PR. I took another look at the current algorithm and couldn't find much room for improvement. Since it generates every combination of edges and then filters out the invalid ones, I looked at the invalid cases and found that many of the rejected edge combinations do not include all of the nodes. These can be rejected quickly from the edges alone, before any diagrams are created. For complex models this gives a noticeable improvement:

```
$ asv continuous master issue_22_improve_partial_diagram_ago -b time_generate_partial_*
Couldn't load asv.plugins._mamba_helpers because
No module named 'conda'
· Creating environments
· Discovering benchmarks
·· Uninstalling from conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
·· Installing 4131d72e <issue_22_improve_partial_diagram_ago> into conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
· Running 2 total benchmarks (2 commits * 1 environments * 1 benchmarks)
[ 0.00%] · For kda commit 190f3aeb <master> (round 1/2):
[ 0.00%] ·· Building for conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[ 0.00%] ·· Benchmarking conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[25.00%] ··· Running (bench_diagrams.PartialDiagrams.time_generate_partial_diagrams--).
[25.00%] · For kda commit 4131d72e <issue_22_improve_partial_diagram_ago> (round 1/2):
[25.00%] ·· Building for conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[25.00%] ·· Benchmarking conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[50.00%] ··· Running (bench_diagrams.PartialDiagrams.time_generate_partial_diagrams--).
[50.00%] · For kda commit 4131d72e <issue_22_improve_partial_diagram_ago> (round 2/2):
[50.00%] ·· Benchmarking conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[75.00%] ··· ...rtialDiagrams.time_generate_partial_diagrams                ok
[75.00%] ··· ============== ============= ============
             --                    return_edges
             -------------- --------------------------
                 graph           True        False
             ============== ============= ============
                3-state        621±20μs    291±100μs
              Hill-5-state    653±200μs    700±300μs
              Hill-8-state   2.50±0.07ms   2.67±0.2ms
              EmrE-8-state     15.6±2ms    14.8±0.8ms
                 Max-7         626±40ms     683±90ms
             ============== ============= ============

[75.00%] · For kda commit 190f3aeb <master> (round 2/2):
[75.00%] ·· Building for conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[75.00%] ·· Benchmarking conda-py3.9-pip+networkx-pip+numpy-pip+pytest-pip+sympy
[100.00%] ··· ...rtialDiagrams.time_generate_partial_diagrams                ok
[100.00%] ··· ============== ============= ============
              --                    return_edges
              -------------- --------------------------
                  graph           True        False
              ============== ============= ============
                 3-state        299±90μs     298±90μs
               Hill-5-state    507±100μs     524±20μs
               Hill-8-state    3.25±0.3ms   2.88±0.1ms
               EmrE-8-state   17.6±0.06ms    18.6±1ms
                  Max-7         1.13±0s     1.12±0.02s
              ============== ============= ============
```

| Change   | Before [190f3aeb] <master>   | After [4131d72e] <issue_22_improve_partial_diagram_ago>   |   Ratio | Benchmark (Parameter)                                                                |
|----------|------------------------------|-----------------------------------------------------------|---------|--------------------------------------------------------------------------------------|
| +        | 299±90μs                     | 621±20μs                                                  |    2.08 | bench_diagrams.PartialDiagrams.time_generate_partial_diagrams('3-state', True)       |
| -        | 18.6±1ms                     | 14.8±0.8ms                                                |    0.79 | bench_diagrams.PartialDiagrams.time_generate_partial_diagrams('EmrE-8-state', False) |
| -        | 1.12±0.02s                   | 683±90ms                                                  |    0.61 | bench_diagrams.PartialDiagrams.time_generate_partial_diagrams('Max-7', False)        |
| -        | 1.13±0s                      | 626±40ms                                                  |    0.56 | bench_diagrams.PartialDiagrams.time_generate_partial_diagrams('Max-7', True)         |

```
SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.
```

On this particular run the 3-state model with `return_edges=True` came out slower (299±90μs before, 621±20μs after), which is why asv reports PERFORMANCE DECREASED, but I'm not worried about the ~300 microseconds we lost 😄

nawtrey commented 3 months ago

The change here is simple and a clear improvement, merging now.