jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[Request] Complexity Information #1090

Closed by RotatingGiraffe 5 months ago

RotatingGiraffe commented 5 months ago

Hi there, could you please add information about the runtime of your implementations in O-notation? I've seen that in the scikit-learn documentation and find it quite useful. I'm specifically interested in the complexity of estimating Markov chains.

For context: I am currently using your library for a university project (so thanks for providing this, very cool 😄 ) and am looking for information about the computational complexity of estimating Markov chains. I am having trouble deriving it from your code and can't really find any general sources about it online (which is odd, maybe I am looking in the wrong places).

jmschrei commented 5 months ago

That's a good question. Unfortunately, right now I don't have time to go through and figure out/confirm the O notation for each algorithm. In practice, I find O notation to only be somewhat useful in the era of SIMD operations and GPUs. If you have a specific algorithm you want confirmation for, I can try to respond.

RotatingGiraffe commented 5 months ago

Hi, thanks for the quick answer! For my project I am aiming to run on an edge device with few resources, which is why I am interested in the complexity. It's definitely interesting to hear that O-notation is sometimes not that important, since lectures make it seem like a huge deal.

I am specifically interested in the complexity of what happens when I call MarkovChain.fit(). It seems to have something to do with distribution.summarize()? If it uses CategoricalConditional, would lines 248-253 of CategoricalConditional.py (see below) be where the data is actually learned?

  for j in range(self.d):
      strides = torch.tensor(self._xw_sum[j].stride(), device=X.device)
      X_ = torch.sum(X[:, :, j] * strides, dim=-1)

      self._xw_sum[j].view(-1).scatter_add_(0, X_, sample_weight[:,j])
      self._w_sum[j][:] = self._xw_sum[j].sum(dim=-1)
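For readers who find these PyTorch lines opaque, here is a minimal pure-Python sketch of what they appear to compute for a single distribution `j`. It assumes that each row of `X[:, :, j]` holds the integer category ids of one fixed-length window and that `_xw_sum[j]` is a contiguous count table with one axis per window position. The names `row_major_strides` and `count_windows` are hypothetical illustrations, not pomegranate API, and this is not the library's code.

```python
# Illustrative sketch only (pure Python, no PyTorch). It mimics what the quoted
# lines appear to do for one distribution j; all names here are hypothetical.


def row_major_strides(shape):
    """Strides (in elements) of a C-contiguous array with the given shape."""
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return strides


def count_windows(windows, weights, n_categories):
    """Accumulate weighted counts of fixed-length windows of category ids.

    windows: list of equal-length tuples of ints in [0, n_categories)
    weights: one weight per window (playing the role of sample_weight[:, j])
    Returns (xw_sum, w_sum) as flat lists, mirroring _xw_sum[j] / _w_sum[j].
    """
    width = len(windows[0])
    strides = row_major_strides([n_categories] * width)

    # Flat count table, playing the role of self._xw_sum[j].view(-1)
    xw_sum = [0.0] * (n_categories ** width)

    for window, w in zip(windows, weights):
        # Like torch.sum(X[:, :, j] * strides, dim=-1): turn the window into
        # one flat index into the table, O(width) work per window
        flat = sum(x * s for x, s in zip(window, strides))
        # Like scatter_add_: add this window's weight to its cell
        xw_sum[flat] += w

    # Like self._xw_sum[j].sum(dim=-1): sum over the last axis, giving the
    # weighted count of each (width - 1)-long prefix. The last axis has
    # stride 1, so each prefix owns a contiguous block of n_categories cells.
    w_sum = [
        sum(xw_sum[p * n_categories:(p + 1) * n_categories])
        for p in range(n_categories ** (width - 1))
    ]
    return xw_sum, w_sum
```

In this sketch the counting loop does O(width) work per window, while the prefix totals at the end touch each table cell once, so that step scales with the table size (`n_categories ** width`) rather than with the number of windows. For example, `count_windows([(0, 1), (0, 1), (1, 0)], [1.0, 1.0, 1.0], 2)` returns `([0.0, 2.0, 1.0, 0.0], [2.0, 1.0])`.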

I see one loop, so O(n)? So it's basically counting how often specific sequences occur and then normalizing to get a probability? PyTorch code is kind of a black box to me, so this is just a guess. Would the order of the chain be just a constant factor, giving something like O(k*n)?

jmschrei commented 5 months ago

O notation is mostly valuable when talking about operations that are done sequentially. When you have operations that can be done massively in parallel (for instance, a matrix multiplication on a GPU), it starts to break down in terms of usefulness. So if an algorithm has a variant with better O notation, but that variant replaces a massively parallel operation that can run on a GPU with a sequential one that has to run on a CPU, the GPU version might still be faster in practice.

Regardless, you're right about the Markov chain. It's O(knl), where k is the length of the chain, n is the number of examples in your data set, and l is the length of those examples. nl can basically be thought of as the total number of k-mers in your data. Fitting the model just means counting the number of times each k-mer occurs, as well as each (k-1)-mer, and then calculating fractions from those counts to get conditional probabilities.
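The counting procedure described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not pomegranate's implementation: `fit_markov_chain` is a hypothetical helper, `k` follows the convention above (k-mers conditioned on (k-1)-mers), and the first k-1 positions of each sequence are simply not modelled here.

```python
from collections import Counter


def fit_markov_chain(sequences, k):
    """Estimate P(next symbol | previous k - 1 symbols) by counting.

    Hypothetical helper for illustration, not pomegranate API.
    sequences: list of sequences (e.g. lists or strings) of hashable symbols
    k: k-mer length, so each prediction conditions on k - 1 previous symbols
    Returns a dict mapping ((k-1)-mer, next_symbol) -> probability.
    """
    kmer_counts = Counter()
    prefix_counts = Counter()

    # One pass over every position of every sequence: about n * l windows,
    # and slicing out each k-long window costs O(k), so O(k * n * l) overall.
    for seq in sequences:
        for i in range(len(seq) - k + 1):
            kmer = tuple(seq[i:i + k])
            kmer_counts[kmer] += 1
            # Count the (k-1)-mer context only where a next symbol follows it,
            # so each context's conditional probabilities sum to 1
            prefix_counts[kmer[:-1]] += 1

    # Conditional probability = count(k-mer) / count(its (k-1)-mer prefix)
    return {
        (kmer[:-1], kmer[-1]): count / prefix_counts[kmer[:-1]]
        for kmer, count in kmer_counts.items()
    }
```

For example, `fit_markov_chain(["aab", "ab"], 2)` sees the 2-mers `aa` once and `ab` twice, so it assigns probability 1/3 to `a` following `a` and 2/3 to `b` following `a`.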

RotatingGiraffe commented 5 months ago

Thank you very much for your help!