cap-jmk opened this issue 2 years ago
Hi, I think this may be a bit too specific to implement as a default. There are different ways of understanding the stability of a Markov state, I would say; for instance, you could
In that sense it might be better for each user to implement their own version of such relabeling. A more efficient variant of yours could for instance be implemented with permutation matrices. :slightly_smiling_face:
so...
msm = estimate_msm(data)
msm_sorted = deeptime.markov.msm.MarkovStateModel(sort_msm(msm.transition_matrix))
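sort_msm in the snippet above is left undefined. A minimal sketch of what it could look like, assuming "stability" means the self-transition probability P[i, i] and a dense NumPy matrix. The name and behavior are hypothetical and not part of deeptime:

```python
import numpy as np

def sort_msm(transition_matrix):
    """Hypothetical helper: reorder states by decreasing self-transition probability."""
    P = np.asarray(transition_matrix)
    # argsort is ascending; reverse it so the most "stable" state comes first
    order = np.argsort(np.diag(P))[::-1]
    # P_sorted[i, j] = P[order[i], order[j]]: rows and columns move together
    return P[np.ix_(order, order)]

P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.8, 0.1],
              [0.0, 0.3, 0.7]])
print(np.diag(sort_msm(P)))  # [0.8 0.7 0.5]
```

Because rows and columns are permuted with the same order, the result is still row-stochastic and can be passed to MarkovStateModel as in the snippet above.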
Maybe you are right; however, it felt like it belonged to the overall Markov modelling, which is part of the deeptime package. I can certainly implement it on my side, but it feels strange. An implementation in deeptime would also improve the readability of the dependencies, and of the code in general, i.e.
msm = estimate_msm(data, sorted=True)
I think the TSM gives a good estimate of the states and serves as a fingerprint in my case. PCCA+ seems like a good idea and could be helpful in some cases.
How would you do it with permutation matrices? The best we could do would be O(n), right? What would the memory consumption look like for permutation matrices? I remember some applications in solving linear systems with these matrices. Would your solution be similar?
To my knowledge there is no canonical way of sorting Markov states, so I do not think it is a good idea to make this a True/False decision. What could be done is to offer a relabeling function such as
msm_sorted = msm.relabel(np.argsort(np.diag(msm.transition_matrix)), inplace=False)
in your case. I do not have the capacity to implement this right now, but I am happy to give pointers and work on pull requests with you. There are multiple layers to this, though. In particular, we'll have to be very careful that this doesn't break any other parts of the library that assume the Markov states stay the same over the course of taking submodels (for example, when restricting yourself to the largest connected component of the jump-probability connectivity graph). Also, there are the following cases to keep in mind:
There are probably more things to keep in mind here. In any case I think the easiest for you is to really sort the matrix on your own and create a new MSM instance.
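The point about other parts of the library can be made concrete: relabeling only the matrix is not enough when other per-state quantities exist. A minimal sketch with two hypothetical helpers, relabel (a stand-in for the proposed msm.relabel) and stationary_distribution (neither is deeptime API), showing that the stationary distribution of the relabeled matrix is the old one permuted by the same order:

```python
import numpy as np

def relabel(transition_matrix, order):
    """Hypothetical stand-in for msm.relabel: new state i is old state order[i]."""
    order = np.asarray(order)
    n = transition_matrix.shape[0]
    # a relabeling must use every state exactly once
    if not np.array_equal(np.sort(order), np.arange(n)):
        raise ValueError("order must be a permutation of range(n)")
    return transition_matrix[np.ix_(order, order)]

def stationary_distribution(P):
    # left eigenvector of P for the eigenvalue closest to 1, normalized to sum to one
    eigvals, eigvecs = np.linalg.eig(P.T)
    pi = np.real(eigvecs[:, np.argmin(np.abs(eigvals - 1))])
    return pi / pi.sum()

rng = np.random.default_rng(0)
P = rng.uniform(0, 1, size=(5, 5))
P /= P.sum(1)[:, None]
order = np.argsort(np.diag(P))[::-1]
P_new = relabel(P, order)

# per-state quantities must be relabeled with the same order to stay consistent
print(np.allclose(stationary_distribution(P_new), stationary_distribution(P)[order]))  # True
```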
Regarding permutation matrices: Yes, we cannot get better than O(n), but we can achieve vectorization.
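On the memory question above: a dense permutation matrix stores n² entries, whereas the same reordering can be expressed with an index array of n entries via NumPy fancy indexing. A sketch comparing the two (the size n and the use of np.ix_ are illustrative choices, not from the thread):

```python
import numpy as np

n = 1000
rng = np.random.default_rng(1)
T = rng.uniform(0, 1, size=(n, n))
T /= T.sum(1)[:, None]
order = np.argsort(np.diag(T))[::-1]

# dense permutation matrix: n x n floats, i.e. O(n^2) extra memory
perm = np.eye(n, dtype=T.dtype)[order]
T_matmul = perm @ T @ perm.T

# index array: n integers, i.e. O(n) extra memory (the result itself is still n x n)
T_index = T[np.ix_(order, order)]

print(np.allclose(T_matmul, T_index))  # True
print(perm.nbytes, order.nbytes)       # bytes used by each representation of the permutation
```

Both variants allocate a new n x n result; they differ in how the permutation itself is stored and in the work needed to apply it (a matrix product versus a plain copy).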
Yes, I know what you mean. I will give my best to support you.
Cool, thanks! :rocket: I think a good first step would be reordering count matrices (in TransitionCountModel). Do you want to have a stab at that? I am still not entirely sure what such a method should be called, as it's not really a sorting but rather a relabeling, in general at least. Perhaps permute? Or reorder? At first I thought transpose might be a good fit, but that is really used more in the context of axes.
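A first sketch of what reordering a count matrix could look like, for both dense NumPy arrays and SciPy sparse matrices. It works on a bare count matrix because the internals of TransitionCountModel are not shown here; reorder_counts and row_normalize are hypothetical helpers, and row_normalize is only the simple non-reversible estimate. The final check shows that reordering counts before normalizing gives the same transition matrix as reordering afterwards:

```python
import numpy as np
import scipy.sparse as sp

def reorder_counts(count_matrix, order):
    """Hypothetical helper: reorder rows and columns of a dense or sparse count matrix."""
    order = np.asarray(order)
    if sp.issparse(count_matrix):
        # row selection on CSR, then column selection; the result stays sparse
        return sp.csr_matrix(count_matrix)[order, :][:, order]
    return np.asarray(count_matrix)[np.ix_(order, order)]

def row_normalize(C):
    # naive (non-reversible) estimate: divide each row by its total count
    C = np.asarray(C, dtype=float)
    return C / C.sum(axis=1, keepdims=True)

C = np.array([[10, 2, 0],
              [3, 40, 5],
              [0, 6, 20]])
order = np.array([1, 2, 0])

C_dense = reorder_counts(C, order)
C_sparse = reorder_counts(sp.csr_matrix(C), order)
print(np.array_equal(C_dense, C_sparse.toarray()))  # True
# reordering counts first gives the same transition matrix as reordering afterwards
print(np.allclose(row_normalize(C_dense), row_normalize(C)[np.ix_(order, order)]))  # True
```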
@clonker, yes, I could give it a try. Where do you want the change to go? I would do the sorting in deeptime/markov/_transition_counting.py
Yes, that would be a good start!
Nice. Okay, I got a working sorting algorithm implemented. However, I would love for you to review it before I start implementing it in deeptime. I don't know why, but I could only make it work with a bubble sort on the diagonal.
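import numpy as np

def sort_markov_matrix(markov_matrix):
    """Takes in a random Markov matrix and
    returns it sorted by decreasing diagonal entry.
    Args:
        markov_matrix (np.array): unsorted matrix
    Returns:
        (np.array): sorted Markov matrix
    """
    # work on a copy so the caller's matrix is not modified in place
    markov_matrix = np.array(markov_matrix, copy=True)
    diag = np.diag(markov_matrix).copy()
    n = len(diag)
    # bubble sort on the diagonal, largest entry first
    for i in range(n):
        for j in range(n - 1 - i):
            if diag[j + 1] > diag[j]:
                # swap rows, columns and the cached diagonal entries together,
                # otherwise later comparisons would use stale values
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
                diag[[j, j + 1]] = diag[[j + 1, j]]
    return markov_matrix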
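The nested loop is not strictly needed: np.argsort produces the whole order in one call, after which a single fancy-indexing step reorders the matrix. One detail matters if labels must be reproducible: with equal diagonal entries, reversing an ascending argsort does not preserve the original order of tied states. A sketch of a tie-stable descending order (the helper name is hypothetical):

```python
import numpy as np

def stable_descending_order(markov_matrix):
    """Order states by decreasing self-transition probability; ties keep their original order."""
    diag = np.diag(markov_matrix)
    # negate instead of reversing an ascending argsort, so tied states are not swapped
    return np.argsort(-diag, kind="stable")

P = np.array([[0.6, 0.4, 0.0, 0.0],
              [0.1, 0.6, 0.3, 0.0],
              [0.0, 0.1, 0.9, 0.0],
              [0.2, 0.0, 0.2, 0.6]])
order = stable_descending_order(P)
print(order)  # [2 0 1 3]
P_sorted = P[np.ix_(order, order)]
print(np.diag(P_sorted))  # [0.9 0.6 0.6 0.6]
```

Finding the order costs O(n log n); applying it is an O(n²) copy of the matrix.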
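So here is a version for dense matrices; ideally, we would support both dense and sparse matrices, though:
import numpy as np
from deeptime.markov.msm import MarkovStateModel
# random row-stochastic 5x5 transition matrix
P = np.random.uniform(0, 1, size=(5, 5))
P /= P.sum(1)[:, None]
msm = MarkovStateModel(P)
diag = np.diag(msm.transition_matrix)
# order states by decreasing self-transition probability
sorting = np.argsort(diag)[::-1]
# rows of the identity taken in the new order form the permutation matrix
perm = np.eye(len(sorting), dtype=msm.transition_matrix.dtype)[sorting]
# perm @ T @ perm.T reorders rows and columns consistently
msm_reordered = MarkovStateModel(np.linalg.multi_dot((perm, msm.transition_matrix, perm.T)))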
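For the sparse case, the same permutation-matrix idea carries over if the permutation matrix itself is stored sparsely. A sketch assuming SciPy sparse matrices (permute_sparse is a hypothetical name); each product then costs roughly time proportional to the number of stored entries rather than a dense n³ product:

```python
import numpy as np
import scipy.sparse as sp

def permute_sparse(T, order):
    """Reorder a sparse matrix as Perm @ T @ Perm.T using a sparse permutation matrix."""
    n = T.shape[0]
    # row i of Perm holds a single 1 in column order[i]: n stored entries instead of n^2
    perm = sp.csr_matrix((np.ones(n), (np.arange(n), order)), shape=(n, n))
    return (perm @ T @ perm.T).tocsr()

# sparse, row-stochastic test matrix (identity added so no row is empty)
n = 200
T = sp.random(n, n, density=0.05, format="csr", random_state=0) + sp.identity(n, format="csr")
row_sums = np.asarray(T.sum(axis=1)).ravel()
T = (sp.diags(1.0 / row_sums) @ T).tocsr()

order = np.argsort(T.diagonal())[::-1]
T_sorted = permute_sparse(T, order)
print(np.allclose(T_sorted.toarray(), T.toarray()[np.ix_(order, order)]))  # True
```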
I see what you meant. With the multi-dot, you would always do Θ(2n) operations, whereas if you implement the sorting manually, you would do O(sqrt(n)) operations. Or am I overlooking something?
~Nope, not overlooking anything.~ While it probably warrants a benchmark, I would imagine that multi_dot outperforms manual sorting in Python, though. Things are different if you implement the sorting in an extension.
Edit: Actually matrix multiplications are (naively) Θ(n^3). In any case, here we can see that complexity =/= efficiency. 🙂
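The benchmark referred to in this thread is not preserved here. For anyone who wants to reproduce the comparison, here is a minimal timeit harness (not the original benchmark) that times multi_dot against plain fancy indexing as an extra baseline; the timings will depend on the machine and the BLAS build:

```python
import timeit
import numpy as np

def by_matmul(T, order):
    # dense permutation matrix, applied with two matrix products
    perm = np.eye(len(order), dtype=T.dtype)[order]
    return np.linalg.multi_dot((perm, T, perm.T))

def by_index(T, order):
    # same reordering via an index array
    return T[np.ix_(order, order)]

for n in (100, 300, 1000):
    rng = np.random.default_rng(n)
    T = rng.uniform(0, 1, size=(n, n))
    T /= T.sum(1)[:, None]
    order = np.argsort(np.diag(T))[::-1]
    assert np.allclose(by_matmul(T, order), by_index(T, order))
    t_mat = min(timeit.repeat(lambda: by_matmul(T, order), number=3, repeat=3))
    t_idx = min(timeit.repeat(lambda: by_index(T, order), number=3, repeat=3))
    print(f"n={n}: multi_dot {t_mat:.4f}s, fancy indexing {t_idx:.4f}s")
```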
Totally enlightening. Since you already have the benchmark written, I would be interested in how it behaves beyond 10k samples. That is usually where things get messy.
To satisfy your curiosity:
Now what would be interesting is the scaling behavior compared with a C/C++-coded sorting extension and with sparse matrices. Estimating a dense transition matrix with 10k Markov states is a tough task anyway because of the massive amounts of data you'd need.
Okay, I see. Maybe putting the loop into @njit() could help? I don't see why it should be slower. Reindexing should be faster than multiplying loads of elements, I guess. Code:
import numpy as np
from numba import njit

@njit(parallel=True)
def sort_markov_matrix(markov_matrix):
    """Takes in a random Markov matrix and
    returns it sorted by decreasing diagonal entry.
    Args:
        markov_matrix (np.array): unsorted matrix
    Returns:
        (np.array): sorted Markov matrix
    """
    diag = np.diag(markov_matrix)
    for i in range(len(diag)):
        for j in range(len(diag) - 1):
            if diag[j + 1] > diag[j]:
                # list-based fancy indexing like this may not compile in numba's
                # nopython mode, which only allows a single 1-D array as advanced index
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
    return markov_matrix
njit doesn't work for this function on my machine; also, I don't really want to pull another dependency into deeptime. If you want to put together a Python-bound C++ implementation, I'm happy to benchmark it, though. The jit performance is comparable to that of the Python loop.
In any case, I think the vectorized permutation-matrix implementation is a good middle ground between a lot of implementation work and harder-to-maintain code (C++ extension) on the one hand, and easy-to-write, easy-to-maintain code with poor performance (Python loop) on the other.
Good point. I would be happy to provide a C++ implementation; however, I am not sure how to link it. Do you have any resources on that? Otherwise, shall we go with your implementation?
Hi, any progress on this?
Yes, I have just been learning some more C++ and dealing with uni politics, and I will have more time from now on to work on it :)
Cool, let me know if you need pointers / help!
Is your feature request related to a problem? Please describe. I am doing Markov modelling for SAR/QSAR analysis of chemical compounds and need sorted Markov matrices.
I suggest sorting the Markov matrix by state stability, starting with the most stable state. Something like the following, but with better memory management:
Test with
What do you think?