JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

Variational Fourier Features #449

Open ahwillia opened 4 months ago

ahwillia commented 4 months ago

Thanks for the great package. I am interested in fitting GPs using the Variational Fourier Features framework described in Hensman et al. 2018. From quickly looking through the documentation, I don't believe it is currently implemented (sorry if I missed it). Does anyone know of any other implementations of this approach in jax?

If there is interest in including this I may be able to help contribute.

thomaspinder commented 4 months ago

Hey @ahwillia, thanks for your interest! We do not currently have a VFF implementation. However, there is certainly interest in this work among the developers. If you have ideas for an implementation or contribution, I'd be interested to hear your thoughts.

ahwillia commented 3 months ago

Thanks. I have been playing around with a few different things and developing some prototypes. Eventually I may have something that you'd be interested in merging.

I noticed that there is an RFF class that is used to draw approximate samples from the GP prior. An alternative use of this class would be to fit an approximate GP regression model, as proposed in the original Rahimi and Recht paper. Is this "weight space view" of GP regression something you plan to flesh out? What I have in mind are extensions of this basic approach.
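To make the "weight space view" concrete, here is a minimal JAX sketch, assuming an RBF kernel with unit-variance features, a Gaussian prior `w ~ N(0, variance * I)` on the feature weights, and Gaussian observation noise with variance `noise_variance`. It does not use GPJax's RFF class; `sample_rff_params`, `rff_features`, `fit_rff_regression` and `predict_rff_regression` are hypothetical names chosen for illustration. The GP is replaced by Bayesian linear regression on `m` random cosine features, so only an `m × m` matrix is ever factorised.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def sample_rff_params(key, num_features, input_dim, lengthscale):
    """Draw RFF frequencies and phases for an RBF kernel (hypothetical helper)."""
    key_w, key_b = jax.random.split(key)
    # The RBF kernel's spectral density is Gaussian with standard deviation 1 / lengthscale.
    omega = jax.random.normal(key_w, (num_features, input_dim)) / lengthscale
    phase = jax.random.uniform(key_b, (num_features,), minval=0.0, maxval=2.0 * jnp.pi)
    return omega, phase


def rff_features(x, omega, phase):
    """phi(x) = sqrt(2/m) cos(omega x + b); E[phi(x) . phi(x')] = exp(-|x - x'|^2 / (2 l^2))."""
    num_features = omega.shape[0]
    return jnp.sqrt(2.0 / num_features) * jnp.cos(x @ omega.T + phase)


def fit_rff_regression(x, y, omega, phase, variance, noise_variance):
    """Bayesian linear regression y = phi(x)^T w + eps, w ~ N(0, variance I), eps ~ N(0, noise_variance)."""
    phi = rff_features(x, omega, phase)  # (n, m)
    m = phi.shape[1]
    # Weight posterior is N(A^{-1} Phi^T y, noise_variance A^{-1}) with A = Phi^T Phi + (noise_variance / variance) I.
    a = phi.T @ phi + (noise_variance / variance) * jnp.eye(m)
    chol = jnp.linalg.cholesky(a)  # A = L L^T
    z = solve_triangular(chol, phi.T @ y, lower=True)  # L^{-1} Phi^T y
    w_mean = solve_triangular(chol.T, z, lower=False)  # L^{-T} L^{-1} Phi^T y
    return w_mean, chol


def predict_rff_regression(x_new, w_mean, chol, omega, phase, noise_variance):
    """Predictive mean and variance of the latent function f at new inputs."""
    phi_new = rff_features(x_new, omega, phase)  # (n_new, m)
    mean = phi_new @ w_mean
    # var_f(x*) = noise_variance * phi(x*)^T A^{-1} phi(x*) = noise_variance * |L^{-1} phi(x*)|^2
    v = solve_triangular(chol, phi_new.T, lower=True)  # (m, n_new)
    var = noise_variance * jnp.sum(v**2, axis=0)
    return mean, var


# Usage on toy data: noise-free samples of sin(x).
x = jnp.linspace(-3.0, 3.0, 50)[:, None]
y = jnp.sin(x[:, 0])
omega, phase = sample_rff_params(jax.random.PRNGKey(0), 200, 1, lengthscale=1.0)
w_mean, chol = fit_rff_regression(x, y, omega, phase, variance=1.0, noise_variance=1e-2)
mean, var = predict_rff_regression(x, w_mean, chol, omega, phase, noise_variance=1e-2)
```

Fitting this way costs O(nm² + m³) rather than the O(n³) of exact GP regression, which is the main appeal when n is large and a few hundred features suffice.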

houstonwarren commented 2 months ago

Hi all, +1 to the general idea of sparse GP posterior/likelihood computations using the RFF kernel. This is something I'm working on myself, and I would love to hear if anyone has opinions on the best way to implement it given the current setup.
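On the likelihood side, a minimal JAX sketch of the log marginal likelihood of a feature-based GP approximation, assuming an `(n, m)` feature matrix `phi` (for example random Fourier features), a weight prior with variance `variance`, and Gaussian noise with variance `noise_variance`. The covariance is then `C = variance * Phi Phi^T + noise_variance * I`; the Woodbury identity and the matrix determinant lemma reduce everything to one `m × m` Cholesky factorisation. The function name `rff_log_marginal_likelihood` is hypothetical and not part of GPJax.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def rff_log_marginal_likelihood(phi, y, variance, noise_variance):
    """log N(y | 0, variance * Phi Phi^T + noise_variance * I) in O(n m^2) time.

    phi: (n, m) feature matrix, e.g. random Fourier features; y: (n,) targets.
    """
    n, m = phi.shape
    # Only the m x m matrix A = Phi^T Phi + (noise_variance / variance) I is factorised.
    a = phi.T @ phi + (noise_variance / variance) * jnp.eye(m)
    chol = jnp.linalg.cholesky(a)  # A = L L^T
    # Woodbury identity: y^T C^{-1} y = (y^T y - |L^{-1} Phi^T y|^2) / noise_variance
    alpha = solve_triangular(chol, phi.T @ y, lower=True)
    quad = (y @ y - alpha @ alpha) / noise_variance
    # Determinant lemma: log|C| = n log(noise_variance) + m log(variance / noise_variance) + log|A|
    logdet = (
        n * jnp.log(noise_variance)
        + m * jnp.log(variance / noise_variance)
        + 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    )
    return -0.5 * (quad + logdet + n * jnp.log(2.0 * jnp.pi))


# Usage with a stand-in feature matrix: gradients w.r.t. the two scalar hyperparameters.
key_phi, key_y = jax.random.split(jax.random.PRNGKey(0))
phi = jax.random.normal(key_phi, (40, 30)) / jnp.sqrt(30.0)
y = jax.random.normal(key_y, (40,))
lml = rff_log_marginal_likelihood(phi, y, 1.5, 0.1)
grads = jax.grad(rff_log_marginal_likelihood, argnums=(2, 3))(phi, y, 1.5, 0.1)
```

Because the features enter only through matrix products, a lengthscale could also be learned by rebuilding `phi` from fixed base frequencies divided by the lengthscale, so that gradients flow through the features themselves.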