stefmolin / data-morph

Morph an input dataset of 2D points into select shapes, while preserving the summary statistics to a given number of decimal points through simulated annealing. It is intended to be used as a teaching tool to illustrate the importance of data visualization.
https://stefaniemolin.com/data-morph/
MIT License

Performance optimizations of various statistics #201

Open JCGoran opened 1 month ago

JCGoran commented 1 month ago

Is your feature request related to a problem? Please describe.

Not really a problem, more like a potential optimization (I haven't worked out the details to see if it actually works).

So, if I understand how the algorithm works under the hood, we move the points in the dataset one at a time and, at each iteration, compute the mean, the standard deviation, and the correlation coefficient of the new dataset. What stands out performance-wise is that we currently use all of the points to recompute the statistics at every step, which seems wasteful given that only one point has changed.

Describe the solution you'd like

Instead of computing the statistics over the whole dataset, which requires iterating over all $n$ points at least once (and even more operations for the standard deviation and correlation coefficient), we can use the fact that we are only moving one point and write the new statistics as the old ones plus a perturbation. For instance, the new value of the mean is:

$$ \langle X' \rangle = \langle X \rangle + \frac{\delta}{n} $$

where $\delta = x'_i - x_i$ and $n$ is the number of points in the dataset. An analogous formula can be derived for the variance, which is the square of the standard deviation anyway. The formula below is for the population variance; the Bessel-corrected sample variance is the same quantity multiplied by $n/(n-1)$:

$$ \text{Var}(X') = \text{Var}(X) + 2 \frac{\delta}{n}(x_i - \langle X \rangle) + \frac{\delta^2}{n} - \frac{\delta^2}{n^2} $$
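A minimal sketch of the two update formulas above, checked against a full recomputation with NumPy. `update_mean_var` is a hypothetical helper name; `var` is the population variance (NumPy's default `ddof=0`), matching the formula, and both `mean` and `x_old` must be the values from *before* the move.

```python
import numpy as np


def update_mean_var(mean, var, x_old, x_new, n):
    """Return (mean, var) after replacing x_old with x_new in O(1).

    Hypothetical helper: `var` is the population variance (ddof=0),
    and `mean` must be the mean *before* the move.
    """
    delta = x_new - x_old
    new_var = (var + 2 * delta / n * (x_old - mean)
               + delta**2 / n - delta**2 / n**2)
    new_mean = mean + delta / n
    return new_mean, new_var


rng = np.random.default_rng(0)
x = rng.normal(size=1_000)
mean, var = x.mean(), x.var()

# Move one point, update the statistics in O(1), then apply the move.
i, x_new = 42, 3.5
mean, var = update_mean_var(mean, var, x[i], x_new, len(x))
x[i] = x_new

# The Bessel-corrected (ddof=1) variance is the same quantity scaled by n / (n - 1).
sample_var = var * len(x) / (len(x) - 1)
```

For a hand check: replacing 1 with 4 in `[1, 2, 3]` gives `[4, 2, 3]`, and `update_mean_var(2.0, 2/3, 1.0, 4.0, 3)` returns a mean of 3 and an unchanged variance of 2/3, as a direct computation confirms.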

and probably for the correlation coefficient (or, better, its square) as well. This would let us update all of the statistics in $O(1)$ time per iteration, instead of $O(n)$ or worse.
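For the correlation coefficient, one option (an assumption, not something worked out in this issue) is to keep running sums of $x$, $y$, $x^2$, $y^2$, and $xy$. Moving one point changes each sum by a single term, and Pearson's $r$, the means, and the standard deviations all follow from the sums in $O(1)$. The `RunningStats` class below is a hypothetical sketch. Be aware that this raw-sums form is known to lose precision through cancellation when the data sit far from zero relative to their spread, which is exactly the stability concern raised below.

```python
import numpy as np


class RunningStats:
    """Hypothetical helper: running sums so mean/std/corr update in O(1)."""

    def __init__(self, x, y):
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        self.n = len(x)
        self.sx, self.sy = x.sum(), y.sum()
        self.sxx, self.syy = (x * x).sum(), (y * y).sum()
        self.sxy = (x * y).sum()

    def replace(self, x_old, y_old, x_new, y_new):
        """Account for moving one point from (x_old, y_old) to (x_new, y_new)."""
        self.sx += x_new - x_old
        self.sy += y_new - y_old
        self.sxx += x_new * x_new - x_old * x_old
        self.syy += y_new * y_new - y_old * y_old
        self.sxy += x_new * y_new - x_old * y_old

    def mean(self):
        return self.sx / self.n, self.sy / self.n

    def std(self):
        """Population standard deviations (ddof=0)."""
        n = self.n
        return (np.sqrt((self.sxx - self.sx**2 / n) / n),
                np.sqrt((self.syy - self.sy**2 / n) / n))

    def corrcoef(self):
        """Pearson's r from the running sums (the 1/n factors cancel)."""
        n = self.n
        cov = self.sxy - self.sx * self.sy / n
        var_x = self.sxx - self.sx**2 / n
        var_y = self.syy - self.sy**2 / n
        return cov / np.sqrt(var_x * var_y)


rng = np.random.default_rng(1)
x = rng.normal(size=500)
y = 0.5 * x + rng.normal(size=500)
stats = RunningStats(x, y)

i, x_new, y_new = 7, 2.0, -1.0
stats.replace(x[i], y[i], x_new, y_new)
x[i], y[i] = x_new, y_new
print(stats.corrcoef(), np.corrcoef(x, y)[0, 1])  # the two values should agree closely
```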

There's at least one question I haven't worked out yet: is this numerically stable? Since numerical accuracy is paramount for the code to work properly, the approach isn't very useful if it loses a lot of accuracy; if it is stable, though, it could be worth exploring an implementation.
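One way to probe the stability question empirically: apply many random single-point moves using the $O(1)$ updates for the mean and the population variance, then compare the running variance with a fresh `np.var`. This is a minimal experiment sketch; `measure_drift`, its parameters, and the random-walk moves are assumptions for illustration, not how data-morph actually perturbs points.

```python
import numpy as np


def measure_drift(n=1_000, steps=10_000, offset=0.0, seed=0):
    """Hypothetical experiment: relative error of the incrementally updated
    population variance after `steps` random single-point moves."""
    rng = np.random.default_rng(seed)
    x = rng.normal(loc=offset, scale=1.0, size=n)
    mean, var = x.mean(), x.var()  # ddof=0, matching the update formula
    for _ in range(steps):
        i = rng.integers(n)
        x_new = x[i] + rng.normal(scale=0.1)  # small random move of one point
        delta = x_new - x[i]
        # Variance update uses the *old* mean, so update it before the mean.
        var += 2 * delta / n * (x[i] - mean) + delta**2 / n - delta**2 / n**2
        mean += delta / n
        x[i] = x_new
    return abs(var - x.var()) / x.var()


print(measure_drift())
```

Raising `offset` (data far from zero relative to its spread) or `steps` is where accumulated rounding error would be expected to show up first. If drift does turn out to matter, one cheap option is to recompute the statistics from scratch every few thousand iterations.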

Some references that could be of use (regarding both the computation and numerical stability):

Describe alternatives you've considered

None.

Additional context

None.

stefmolin commented 1 month ago

Make sure you account for the plan to add the median (see #181).

JCGoran commented 1 month ago

After some consideration, it seems the median can be implemented as follows:

  1. Sort the input data.
  2. Split the sorted data into two AVL trees: one of size n // 2 (lower part) and the other of size n - n // 2 (higher part).
  3. The median is then either min(higher part) (n odd) or (max(lower part) + min(higher part)) / 2 (n even).
  4. The replacement $x_i \mapsto x'_i$ is equivalent to removing $x_i$ from the tree that contains it, then inserting $x'_i$ into the appropriate tree (basically a bunch of if statements, depending on where we're removing from and inserting into).
  5. To keep the trees balanced (both trees have the same number of elements when n is even, or the higher part has one extra element when n is odd), we occasionally need to either move the largest element of the lower part into the higher part, or move the smallest element of the higher part into the lower part.
  6. After rebalancing, go back to step 3 to get the median.
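The six steps above can be sketched in Python as follows. This is a minimal sketch, not the code used for the timings: to stay within the standard library, it swaps the two AVL trees for two plain sorted lists maintained with `bisect`, so each insert or removal costs $O(n)$ because of list shifting, not $O(\log n)$. The class name `RunningMedian` is hypothetical.

```python
import bisect


class RunningMedian:
    """Median under single-value replacement, following steps 1-6.

    Assumption: two sorted lists stand in for the two AVL trees, so
    insert/remove are O(n) here rather than O(log n).
    """

    def __init__(self, data):
        values = sorted(data)                 # step 1
        half = len(values) // 2
        self.lower = values[:half]            # step 2: the n // 2 smallest values
        self.higher = values[half:]           # the n - n // 2 largest values

    def median(self):                         # step 3
        if len(self.higher) > len(self.lower):  # n odd
            return self.higher[0]
        return (self.lower[-1] + self.higher[0]) / 2

    def replace(self, old, new):
        # Step 4: every lower value is <= every higher value, so `old` is in
        # the lower part exactly when it is <= max(lower part).
        part = self.lower if self.lower and old <= self.lower[-1] else self.higher
        del part[bisect.bisect_left(part, old)]
        if self.lower and new <= self.lower[-1]:
            bisect.insort(self.lower, new)
        else:
            bisect.insort(self.higher, new)
        # Step 5: one move restores len(higher) - len(lower) in {0, 1}.
        if len(self.lower) > len(self.higher):
            bisect.insort(self.higher, self.lower.pop())   # max(lower) -> higher
        elif len(self.higher) - len(self.lower) > 1:
            self.lower.append(self.higher.pop(0))          # min(higher) -> lower


rm = RunningMedian([5, 1, 3, 2, 4])
rm.replace(1, 10)   # data is now [5, 10, 3, 2, 4]
print(rm.median())  # 4
```

Because a remove plus an insert shifts the size difference by at most 2, a single rebalancing move is always enough. Replacing the lists with a balanced tree would recover the $O(\log n)$ bound without changing the logic.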

An AVL tree performs all of the operations above in $O(\log n)$, so we can find the median of the "perturbed" dataset in $O(\log n)$. For comparison, numpy.median runs in $O(n)$ because it uses a selection algorithm (something like quickselect) rather than the naive "sort and pick", which runs in $O(n \log n)$.
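To illustrate the $O(n)$ baseline: a median can be found with a partial sort via `np.partition` (introselect by default) instead of a full sort. The sketch below is roughly the idea behind `numpy.median`, though the real function also handles axes, NaNs, and other details; `median_by_selection` is a hypothetical name.

```python
import numpy as np


def median_by_selection(a):
    """Median in O(n) average time via np.partition, without a full sort."""
    a = np.asarray(a, dtype=float)
    n = a.size
    k = n // 2
    if n % 2:
        # Odd n: only the k-th order statistic is needed.
        return np.partition(a, k)[k]
    # Even n: place both middle order statistics in one partition call.
    part = np.partition(a, [k - 1, k])
    return (part[k - 1] + part[k]) / 2


print(median_by_selection([4, 1, 3, 2]))  # 2.5
```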

I used the timeit module to check whether this actually pays off, and the results are encouraging (tested on an array with 5M elements, 100 repeats):

Note that I'm not counting the initial sort, which is $O(n \log n)$, in the timings, since we only need to do it once, at the start of the simulation.

Maybe I'm overcomplicating things, but it seems to me that we need a tree-like structure for this. At first I considered using the built-in heapq module for the two-heap algorithm, but heapq doesn't support removing arbitrary elements, only the top one, so I opted for an AVL tree instead.