elixir-nx / scholar

Traditional machine learning on top of Nx
Apache License 2.0

Optimize median #208

Closed: josevalim closed this issue 8 months ago

josevalim commented 10 months ago

I found LazySelect (section 12.2.1). Here is a numpy implementation. Improving our median would benefit Affinity Propagation. Perhaps this could also be contributed to Nx.

krstopro commented 10 months ago

Am I wrong, or does LazySelect only work with tensors whose elements are distinct?

josevalim commented 10 months ago

@krstopro I don't actually know. Do they state this explicitly in the paper? The paper has already been purged from my brain. If that's the case, maybe we can call it uniq_median or distinct_median.

krstopro commented 10 months ago

@josevalim Every reference I see mentions distinct elements, e.g. here, here, and here (the last one says "ordered set", but that should mean a sequence without duplicates).

josevalim commented 10 months ago

I see. So we either need to pick another algorithm or go with uniq_median. I am fine with uniq_median.
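For reference, here is a simplified Python sketch of the LazySelect idea: sample about n^(3/4) elements, pick two sample elements a and b that should bracket the k-th smallest, make one pass to compute the rank of a and collect the candidates in [a, b], and then sort only that small candidate set. It is an illustration, not a port of the numpy implementation mentioned in the issue; the name `lazy_select`, the `max_rounds` fallback and the `n < 100` cutoff are assumptions of this sketch, and corner cases near the ends of the range are left to retries. It also shows the duplicates concern: the answer stays correct with repeated values, because every element in [a, b] is kept, but with many duplicates the size check on the candidate set can keep failing, so the sketch gives up after `max_rounds` and sorts everything.

```python
import math
import random


def lazy_select(s, k, rng=None, max_rounds=20):
    """Return the k-th smallest (1-indexed) element of the list s (simplified LazySelect).

    The expected O(n) analysis assumes distinct elements. With many duplicates the
    candidate set can stay too large, so after max_rounds we fall back to a full sort.
    """
    rng = rng or random.Random()
    n = len(s)
    if n < 100:  # assumption: tiny inputs are cheaper to sort directly
        return sorted(s)[k - 1]
    r_size = math.ceil(n ** 0.75)
    for _ in range(max_rounds):
        # 1. Sample n^(3/4) elements with replacement and sort only the sample.
        r = sorted(rng.choices(s, k=r_size))
        # 2. Choose sample positions that should bracket the k-th smallest of s.
        x = k * n ** -0.25
        lo = max(math.floor(x - math.sqrt(n)), 1)
        hi = min(math.ceil(x + math.sqrt(n)), r_size)
        a, b = r[lo - 1], r[hi - 1]
        # 3. One pass over s: rank of a and the candidates lying in [a, b].
        rank_a = sum(1 for v in s if v < a)
        p = [v for v in s if a <= v <= b]
        # 4. Success if the answer is inside P and P is small; then sort only P.
        if rank_a < k <= rank_a + len(p) and len(p) <= 4 * r_size + 2:
            return sorted(p)[k - rank_a - 1]
    # Too many failed rounds (e.g. heavy duplicates): fall back to a full sort.
    return sorted(s)[k - 1]
```

For the median of n elements, call it with k = (n + 1) // 2 for the lower median.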

josevalim commented 10 months ago

More algorithms:

krstopro commented 10 months ago

@josevalim wrong issue?

Update: Sorry, I thought these dealt with k-NN, but they actually propose new ways to do selection and sorting, which is relevant to this issue as well.

josevalim commented 10 months ago

I am thinking we should go ahead with truncated sort. The median of N elements can be obtained by reshaping the tensor into a matrix of shape {x, y}, doing a parallel sort along x, and then sorting the first k columns of y. We could do z rounds, if desired.

msluszniak commented 9 months ago

Ok, so what I understand from the paper about truncated sort is that we want to implement a quickselect procedure. Is that right?

josevalim commented 9 months ago

Nx already implements an efficient sorting algorithm. The idea is that, to compute the median of the whole tensor (rather than along an axis), we can reshape the tensor to {x, y}, sort the rows in parallel (using the regular Nx.sort on an axis), and then compute the median from the result. That's all. We wouldn't be implementing any new sorting algorithm.
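A minimal sketch of this scheme in plain Python, assuming the number of elements is a positive multiple of x. The final step here, a lazy k-way merge with `heapq.merge` that stops at the middle, is an assumption: the comment only says "compute the median", and a merge is sequential, so an Nx version would likely need a different final step. `median_via_sorted_rows` is a hypothetical name.

```python
import heapq


def median_via_sorted_rows(values, x):
    """Median of a flat list: split it into x rows of length y, sort each row
    independently, then lazily merge the sorted rows up to the middle."""
    n = len(values)
    if n == 0 or n % x != 0:
        raise ValueError("this sketch assumes len(values) is a positive multiple of x")
    y = n // x
    # "Reshape" to {x, y} and sort each row; in Nx this would be a sort along one axis.
    rows = [sorted(values[i * y:(i + 1) * y]) for i in range(x)]
    # Walk the merged order and stop once both middle positions are reached.
    lo, hi = (n - 1) // 2, n // 2
    for i, v in enumerate(heapq.merge(*rows)):
        if i == lo:
            lo_val = v
        if i == hi:
            return (lo_val + v) / 2
```

For example, `median_via_sorted_rows([5, 1, 4, 2, 3, 6], 2)` sorts the rows to `[1, 4, 5]` and `[2, 3, 6]` and returns `3.5`.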

josevalim commented 9 months ago

We can run some benchmarks against our regular median implementation.

msluszniak commented 9 months ago


Ok, I see. Then we definitely need to benchmark the optimal values of x and y, and decide how to deal with the "rest": if the number of elements is prime (or has only a few factors), we will have to reshape to x*y elements plus a remainder.
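One possible way to handle the "rest", sketched here as an assumption rather than something the thread settled on, is to pad the input with +inf up to x * ceil(n / x) elements. Since +inf sorts after every finite value, the n smallest elements of the padded matrix are exactly the original values, so the order statistics at the median positions (computed from the original n) are unchanged. `pad_and_reshape` is a hypothetical helper, and the sketch assumes the input contains no NaN.

```python
import math


def pad_and_reshape(values, x):
    """Split a flat list into x rows of length y = ceil(n / x), filling the
    "rest" with +inf. Assumes no NaN in values (NaN breaks the ordering)."""
    n = len(values)
    y = -(-n // x)  # ceil(n / x) using integer arithmetic
    padded = list(values) + [math.inf] * (x * y - n)
    return [padded[i * y:(i + 1) * y] for i in range(x)]
```

For example, seven elements (a prime count) with x = 2 become two rows of four, with a single +inf in the last slot; the median is still read at position (7 - 1) // 2 of the original count.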

msluszniak commented 9 months ago

I had been thinking of solving this problem with parallel sorting, along the lines of https://cs.stackexchange.com/questions/87695/to-find-median-of-k-sorted-arrays-of-n-elements-each-in-less-than-onk-log

but as I understand it, to find the k smallest elements we split into {x, y}, sort the rows, sort the first ceil(x/k) columns, and then repeat the procedure p times (what p is sufficient?), right?
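Once the rows are sorted, one known way to get the k-th smallest element across them (not necessarily what the linked answer proposes) is a binary search over values: count the elements <= a candidate value with one binary search per row, and narrow the value range until the count reaches k. This sketch assumes integer elements and non-empty rows, each sorted ascending; `kth_smallest_in_sorted_rows` is a hypothetical helper.

```python
from bisect import bisect_right


def kth_smallest_in_sorted_rows(rows, k):
    """k-th smallest (1-indexed) integer across non-empty rows sorted ascending."""
    lo = min(row[0] for row in rows)
    hi = max(row[-1] for row in rows)
    # Invariant: the answer lies in [lo, hi]. Find the smallest v with count(v) >= k.
    while lo < hi:
        mid = (lo + hi) // 2
        # Elements <= mid, counted independently in each row via binary search.
        count = sum(bisect_right(row, mid) for row in rows)
        if count >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo
```

The per-row counts are independent, which is what makes this shape attractive for batched hardware; the number of rounds grows with the logarithm of the value range rather than with the number of elements.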

msluszniak commented 9 months ago


I think the truncated sort approach won't work in general. There are cases where, even after many rounds of this procedure, the median is still wrong.

josevalim commented 9 months ago

Shall we try lazy select then?

msluszniak commented 9 months ago

Yes, I think so :)

msluszniak commented 9 months ago

Do we have any alternative approaches to check?

josevalim commented 9 months ago

The only option left is the "Hierarchical partition" algorithm, described in section III.E of "Efficient Selection Algorithm for Fast k-NN Search on GPUs".

msluszniak commented 9 months ago

Ok, I'll investigate this option

msluszniak commented 8 months ago

I checked the hierarchical partitioning from the paper, and I don't think it will help us calculate the median efficiently. The algorithm is designed to efficiently fetch the k nearest neighbours when the distances are precomputed. It runs in O(n), which is worse than a min-heap's O(k*log(n)), but it avoids the many irregular memory accesses that a min-heap incurs. So we could potentially implement the median as fetching the n/2 nearest elements to the min/max element, but the algorithm doesn't scale well with the number of neighbours: "... the performance improvement will increase with increasing N and decrease with increasing k. This is because when N is increasing more elements can be excluded. On the contrary, when k is increasing, elements will have more chances to be picked during top-down search."
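To make the min-heap baseline in this comment concrete, here is a minimal Python sketch of median selection with a binary min-heap: heapify the values, then pop the smaller half, each pop costing O(log n). This is the standard `heapq` approach, not the paper's hierarchical partition; `median_via_min_heap` is a hypothetical name and returns the lower median.

```python
import heapq


def median_via_min_heap(values):
    """Lower median of a non-empty list using a binary min-heap."""
    heap = list(values)
    heapq.heapify(heap)  # O(n) build
    # Discard the (n - 1) // 2 smallest elements; each pop is O(log n).
    for _ in range((len(heap) - 1) // 2):
        heapq.heappop(heap)
    # The root is now the element at sorted position (n - 1) // 2.
    return heap[0]
```

Framed as in the comment, the elements popped here are the nearest neighbours of the minimum (distance v - min), so fetching the n/2 nearest elements to the min element and reading the last one gives the same median.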