LBL-EESA / fastkde


PDF values at unseen points #8

Closed reisner closed 2 years ago

reisner commented 2 years ago

Hi there,

Is there a way, with this package, to get PDF values at unseen points? Something like sklearn's score_samples function for their KDE library? Or do I have to do some sort of interpolation myself from the PDF returned by fastKDE?

Thanks!

taobrienlbl commented 2 years ago

Hi @reisner, yes, there is such a function: pdf_at_points()

Glad to hear that you might be finding fastKDE useful! Cheers--

reisner commented 2 years ago

Hi @taobrienlbl

I just want to clarify that I'm using that function correctly. When I use pdf_at_points, is it just a single call to that function? I was originally thinking it would be a call to pdf to create the model, followed by a call to pdf_at_points to evaluate the PDF at the query points (analogous to the fit/predict workflow in scikit-learn). However, looking at the docs, I think it's a single call.

Something like this:

import numpy as np
from fastkde import fastKDE

train_x = 50*np.random.normal(size=100) + 0.1
train_y = 0.01*np.random.normal(size=100) - 300

test_x = 50*np.random.normal(size=100) + 0.1
test_y = 0.01*np.random.normal(size=100) - 300

test_points = list(zip(test_x, test_y))
results = fastKDE.pdf_at_points(train_x, train_y, list_of_points = test_points)

Does that look correct?

taobrienlbl commented 2 years ago

Hi @reisner, yes, that's the correct usage of pdf_at_points(). For something more similar to the fit/predict workflow, you could use the object-oriented interface of fastKDE to first calculate the optimal PDF in spectral space (the training phase), followed by a call to evaluate the PDF at specific points (predict). The code for pdf_at_points() in fact just wraps that functionality, so you could mimic the approach there if you'd like; see https://github.com/LBL-EESA/fastkde/blob/2e01cbf06adc94843bf7cd3387cba80a929f1721/fastkde/fastKDE.py#L1306

reisner commented 2 years ago

OK thanks so much!

I think some of this would be valuable to include in the README / docs, unless it's already somewhere else (I haven't been able to find it).

taobrienlbl commented 2 years ago

I absolutely agree! If you happen to have a bit of time to spare to provide a short pull request to augment the documentation, that would be a huge favor to other users of fastKDE. It may be a while before I get a chance to go through and add to the documentation. Cheers--

reisner commented 2 years ago

OK! I added a PR: https://github.com/LBL-EESA/fastkde/pull/9

reisner commented 2 years ago

I've also noticed that this approach is very slow (that's mentioned in the function's comment as well). Is it possible to get the same speedup for this use case?

taobrienlbl commented 2 years ago

Thanks so much!! Unfortunately, no, it's not easy to get the same speedup. It's slow because fastkde.pdf_at_points() uses a direct Fourier transform (DFT) for the final PDF estimation stage, since the points aren't on a regular grid; the cost of the DFT grows like O(N^2), where N is the number of points at which the PDF is estimated. fastkde.pdf() uses a fast Fourier transform for that same stage, which is O(N log N).
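To make the cost difference concrete, here is a minimal NumPy sketch. It is not fastKDE's code: dft_at_points is a hypothetical helper written for illustration, and the 1D periodic test signal is an assumption. It shows that on a regular grid the direct sum and the FFT agree, while off the grid only the direct sum applies, and that every query point has to touch every frequency.

```python
import numpy as np

def dft_at_points(coeffs, freqs, points):
    """Evaluate sum_k coeffs[k] * exp(1j * freqs[k] * x) for each x in points.

    Hypothetical helper for illustration only (not part of fastKDE).
    The cost is O(len(points) * len(freqs)): each point touches each frequency.
    """
    phase = np.outer(points, freqs)  # phase[j, k] = points[j] * freqs[k]
    return np.exp(1j * phase) @ coeffs

# A regular grid of n samples on [0, 2*pi) and a smooth periodic test signal.
n = 64
grid = 2 * np.pi * np.arange(n) / n
signal = np.exp(np.cos(grid))

# Fourier coefficients and their integer frequencies (0, 1, ..., -2, -1).
coeffs = np.fft.fft(signal) / n
freqs = np.fft.fftfreq(n, d=1.0 / n)

# On the regular grid the inverse FFT recovers the signal in O(n log n) ...
on_grid_fft = np.fft.ifft(coeffs * n).real
# ... and the direct sum gives the same values, but in O(n^2).
on_grid_direct = dft_at_points(coeffs, freqs, grid).real

# Off the grid the FFT no longer applies, so only the direct sum is available.
rng = np.random.default_rng(0)
query = rng.uniform(0, 2 * np.pi, size=5)
off_grid = dft_at_points(coeffs, freqs, query).real
```

With the same number of query points as frequencies, the direct sum is the O(N^2) case described above.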

Technically, it should be possible to implement a variant of a non-uniform FFT, but it's a different type of nuFFT than is used in the forward transform part of fastKDE. The forward transform goes from an unstructured set of points to a structured set of points, whereas the version here would need to do the opposite. This ends up being a totally different algorithm. It's technically possible but would be a substantial amount of work to implement. So far you're only the second person I'm aware of who has made use of pdf_at_points(), so this isn't something I've thought a lot about.

Of course now that I'm thinking about it, I'm seeing other ways we could get a speed-up here without having to write a bunch of new code. Threading and/or GPU calculations could be useful here. The DFT is embarrassingly parallel, so it should be simple to use cython.parallel.prange for the outer loop (see https://github.com/LBL-EESA/fastkde/blob/2e01cbf06adc94843bf7cd3387cba80a929f1721/fastkde/nufft.pyx#L603). Of course I say that this should be simple, but I usually find that threading is more complicated than I initially think it should be.

The dft_points() function could also be rewritten using numba, which can be used to parallelize on GPUs.
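As a rough illustration of why the direct transform is embarrassingly parallel, here is a sketch that splits the query points into chunks and evaluates each chunk independently. Note the substitutions: it uses Python threads over NumPy chunks rather than cython.parallel.prange or numba, and dft_chunk and parallel_dft_at_points are hypothetical names, not fastKDE functions. It shows the decomposition only; any actual speed-up depends on how much of the work NumPy does outside the GIL.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def dft_chunk(coeffs, freqs, points):
    # Direct Fourier sum for one chunk of points; it needs nothing from other chunks.
    return np.exp(1j * np.outer(points, freqs)) @ coeffs

def parallel_dft_at_points(coeffs, freqs, points, n_workers=4):
    # Hypothetical helper: one chunk of points per worker thread.
    chunks = np.array_split(points, n_workers)
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # pool.map returns results in input order, so concatenation restores
        # the original point order.
        results = list(pool.map(lambda c: dft_chunk(coeffs, freqs, c), chunks))
    return np.concatenate(results)

# Made-up spectral coefficients and query points, for demonstration only.
rng = np.random.default_rng(0)
freqs = np.arange(-16, 17, dtype=float)
coeffs = rng.normal(size=freqs.size) + 1j * rng.normal(size=freqs.size)
points = rng.uniform(0, 2 * np.pi, size=1000)

serial = dft_chunk(coeffs, freqs, points)
threaded = parallel_dft_at_points(coeffs, freqs, points)
```

Because each output value depends only on its own point, the chunked result matches the serial one; the same independence is what would make a prange loop over the outer (points) dimension safe.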