neurodata / hyppo

Python package for multivariate hypothesis testing
https://hyppo.neurodata.io/

Question for improvement #402

Closed: quant12345 closed this issue 9 months ago

quant12345 commented 11 months ago

Hi, everyone!

I would like to improve the code of the 'prim' function used by the 'FriedmanRafsky' class.

Because the function is decorated with numba's @jit(nopython=True, cache=True), the first call includes a JIT warm-up and takes a long time. For a 100 x 100 array, the first run was 30 to 120 times slower than my version. Subsequent runs, however, are about 5 times faster than my version.
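Since the comparison depends on whether the warm-up is included, it can help to report the first call and the later calls separately. The helper below is a minimal sketch using only the standard library; 'first_and_steady_time' and its 'repeats' parameter are hypothetical names, not part of hyppo or numba.

```python
import statistics
import time


def first_and_steady_time(fn, *args, repeats=5):
    """Hypothetical helper: time the first call apart from later calls.

    For a numba @jit function, the first call in a process pays for
    compilation (or for loading the cached machine code when cache=True),
    so it is reported separately from the median of the later calls.
    """
    start = time.perf_counter()
    fn(*args)
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        later.append(time.perf_counter() - start)
    return first, statistics.median(later)
```

Running it once on the decorated function and once on the pure-NumPy version, with the same 100 x 100 input, gives the first-run ratio and the steady-state ratio as two separate numbers instead of one mixed figure.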

In my version, I removed two loops and three conditional statements, replacing them with vectorized NumPy operations. Before:

for i in range(V):
    if selected[i]:
        for j in range(V):
            # not in selected and there is an edge
            if (not selected[j]) and weight_mat[i][j]:
                if minimum > weight_mat[i][j]:
                    minimum = weight_mat[i][j]
                    x = i
                    y = j

After:

import numpy as np

# Rows: vertices already selected (True); columns: vertices not yet selected (False).
rows_true = np.where(selected)[0]
columns_false = np.where(np.logical_not(selected))[0]
# From these rows and columns, build the index pairs 'ind'
# used to fetch values from the array 'weight_mat'.
ind = np.array(np.meshgrid(rows_true, columns_false, indexing='ij')).reshape(2, -1)
sample = weight_mat[ind[0], ind[1]]
# 'i_min' holds the position of the smallest non-zero value in 'sample';
# it is used to recover the matching row and column of 'weight_mat'.
# Note: np.min raises ValueError if 'sample' has no non-zero entries.
i_min = np.where((sample < minimum) & (sample == np.min(sample[np.nonzero(sample)])))

x = ind[0][i_min[0]][0]
y = ind[1][i_min[0]][0]
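To check the vectorized search against the loop before wiring it into 'prim', both can be wrapped as functions and compared on random matrices. This is a sketch, not hyppo code: 'min_edge_loop' and 'min_edge_vectorized' are hypothetical names, it swaps np.meshgrid for np.ix_ (which selects the same selected-rows by unselected-columns sub-matrix), and it returns None when no edge beats 'minimum', where the meshgrid version would raise.

```python
import numpy as np


def min_edge_loop(weight_mat, selected, minimum):
    """Hypothetical wrapper around the original nested-loop search.

    Returns the (i, j) edge found, or None if no edge beats 'minimum'.
    """
    best = None
    V = len(selected)
    for i in range(V):
        if selected[i]:
            for j in range(V):
                # not in selected and there is an edge
                if (not selected[j]) and weight_mat[i][j]:
                    if minimum > weight_mat[i][j]:
                        minimum = weight_mat[i][j]
                        best = (i, j)
    return best


def min_edge_vectorized(weight_mat, selected, minimum):
    """Hypothetical vectorized equivalent of 'min_edge_loop'."""
    rows_true = np.where(selected)[0]
    columns_false = np.where(np.logical_not(selected))[0]
    # Sub-matrix of edges from selected rows to unselected columns
    sub = weight_mat[np.ix_(rows_true, columns_false)]
    # A zero weight means "no edge" in the loop, so hide zeros from the minimum
    masked = np.where(sub != 0, sub, np.inf)
    if masked.size == 0:
        return None
    # argmin returns the first minimum in row-major order; the loop's strict
    # '>' comparison likewise keeps the first (i, j) it meets on ties
    flat = int(np.argmin(masked))
    if not masked.flat[flat] < minimum:
        return None
    r, c = np.unravel_index(flat, masked.shape)
    return int(rows_true[r]), int(columns_false[c])
```

Both functions should return the same pair for the same 'weight_mat', 'selected' mask and starting 'minimum'. Matrices with small integer weights make a useful stress case, since they contain many ties and many zero (missing) edges.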

'Rafsky_with_numba' uses numba; 'Rafsky_no_numba' does not. With a 512 x 512 array, 'Rafsky_no_numba' is about 29 times faster. Both files generate 'weight_mat' arrays of size n x n; 'labels' is a one-dimensional array of length n.
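The attached files are not reproduced here, so the generator below is only a guess at inputs with the stated shapes: a symmetric n x n matrix of pairwise Euclidean distances for 'weight_mat' and a 0/1 array of length n for 'labels'. 'make_inputs' and its 'dim' and 'seed' parameters are hypothetical. A fixed seed lets both versions be benchmarked on identical data.

```python
import numpy as np


def make_inputs(n, dim=2, seed=0):
    """Hypothetical benchmark inputs: an n x n distance matrix and n labels."""
    rng = np.random.default_rng(seed)
    points = rng.normal(size=(n, dim))
    # Pairwise Euclidean distances via broadcasting: shape (n, n)
    diff = points[:, None, :] - points[None, :, :]
    weight_mat = np.sqrt((diff ** 2).sum(axis=-1))
    # Two-sample labels, each 0 or 1
    labels = rng.integers(0, 2, size=n)
    return weight_mat, labels
```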

'Rafsky_with_numba' prints the time each algorithm takes and the difference between them at each iteration.

'Rafsky_no_numba' uses the perfplot library to plot the performance of each function. Both files check that the two algorithms produce equal results.

The files are attached as an archive (Rafsky.tar.gz).

What does the community think about this?

sampan501 commented 11 months ago

Adding this functionality sounds like a great idea! Can you make a PR with the changes and plot the wall-time curve for the new code? Then I can review it.

quant12345 commented 11 months ago

Yes, that's exactly what I want to do. But first I wanted to ask: will you look at the two attached files? My code does not work with numba, so we need to decide which matters more: a fast first run, or faster subsequent runs with numba.

sampan501 commented 11 months ago

Well, I would first add the version that works, with a unit test checking that, given the same data, the old and new algorithms give the same results. You can make another PR with the numba fixes.

I just want to make sure that nothing breaks when you change the prim function.

quant12345 commented 11 months ago

Okay, I'll make a pull request. However, I can't produce tests with timing curves from two branches, the original and mine (maybe it's possible, but I don't know how).

sampan501 commented 11 months ago

I meant a test for validity: given the same dataset and seed, do you get the same test statistic and p-value with the existing code and your code?
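A sketch of why a fixed seed makes that comparison meaningful, assuming the statistic is the classic Friedman-Rafsky count of MST edges joining points with different labels (the runs statistic is this count plus one; hyppo's FriedmanRafsky may normalize it or compute the p-value differently). The MST depends only on the pooled data, so only the labels are permuted. Two MST implementations that return the same edge set, in any order, therefore give identical statistics and p-values for the same seed. 'cross_label_edges' and 'permutation_pvalue' are hypothetical names.

```python
import numpy as np


def cross_label_edges(edges, labels):
    """Hypothetical helper: number of MST edges joining different labels."""
    edges = np.asarray(edges)
    labels = np.asarray(labels)
    return int(np.sum(labels[edges[:, 0]] != labels[edges[:, 1]]))


def permutation_pvalue(edges, labels, reps=1000, seed=0):
    """Hypothetical seeded permutation test on the cross-label edge count.

    The MST depends only on the pooled data, so only the labels are shuffled.
    Few cross-label edges suggest the samples differ, so the lower tail is used.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    stat = cross_label_edges(edges, labels)
    null = np.array(
        [cross_label_edges(edges, rng.permutation(labels)) for _ in range(reps)]
    )
    pvalue = (1 + np.sum(null <= stat)) / (1 + reps)
    return stat, float(pvalue)
```

Calling it once with the edges from the existing 'prim' and once with the edges from the new version, with the same 'seed', should give exactly equal pairs.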

For the wall times, you don't need unit tests. A figure like the one you already made is sufficient, just generated with hyppo API-compliant code.

quant12345 commented 11 months ago

The tests are failing. I assume this is because the 'MST' function, which calls 'prim', has a numba decorator. When I tried locally without the decorator, the function worked without errors. I'll try to rewrite the 'MST' function without loops so that it does not need the decorator. If possible, could you show sample data (x, labels) for the MST function?
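One way to remove the nested loops is the standard dense form of Prim's algorithm. It keeps, for every vertex outside the tree, the cheapest known edge to the tree and updates it with one vectorized step per added vertex. The outer loop over the V - 1 additions remains, since each step depends on the previous one, but the total cost drops to O(V^2). This is a sketch, not hyppo's 'MST' or 'prim': 'prim_dense' is a hypothetical name, zeros are treated as missing edges as in the loop version, and on ties it may pick a different (equally minimal) tree than the original code, so comparisons should use the edge set or the total weight.

```python
import numpy as np


def prim_dense(weight_mat):
    """Hypothetical O(V^2) Prim's MST with a single Python loop.

    Zero entries are treated as "no edge", as in the nested-loop version.
    Returns the (parent, child) edges in the order they join the tree.
    """
    w = np.asarray(weight_mat, dtype=float)
    w = np.where(w != 0, w, np.inf)
    V = w.shape[0]
    in_tree = np.zeros(V, dtype=bool)
    in_tree[0] = True
    best = w[0].copy()  # cheapest known edge from the tree to each vertex
    parent = np.zeros(V, dtype=int)  # tree endpoint of that cheapest edge
    edges = []
    for _ in range(V - 1):
        candidates = np.where(in_tree, np.inf, best)
        y = int(np.argmin(candidates))
        if np.isinf(candidates[y]):
            raise ValueError("graph is disconnected")
        edges.append((int(parent[y]), y))
        in_tree[y] = True
        # Outside vertices that are now closer to the tree through y
        closer = ~in_tree & (w[y] < best)
        best[closer] = w[y][closer]
        parent[closer] = y
    return edges
```

For points at 0, 1, 3 and 6 on a line, with 'weight_mat' the matrix of absolute differences, prim_dense returns [(0, 1), (1, 2), (2, 3)].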