JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

feat: vector-valued GPs #88

Closed emilemathieu closed 2 months ago

emilemathieu commented 2 years ago

Hi all,

First, thanks for this nice library that I've recently started using! :)

I'd be interested in vector-valued GPs and, from what I understand, this is not supported yet, right? Or am I missing something? I've passed a kernel function which returns matrices of shape [d x d].

```python
gp = gpx.Prior(kernel=RBFCurlFree())
gp = gp(dict(kernel=kernel.params, mean_function={}))
y = gp(x)
```

I believe the way to deal with this is to rearrange/reshape things into an [N*d, N*d] form, but I don't really have much experience with how to handle vector-valued GPs.
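For concreteness, here is a minimal sketch (not GPJax API) of what that flattening could look like, using a placeholder matrix-valued kernel (an RBF times the identity) in place of RBFCurlFree:

```python
import jax.numpy as jnp
from jax import vmap

d = 2  # output dimension

def matrix_kernel(x, y):
    # Placeholder matrix-valued kernel: an RBF scaled identity. Any kernel
    # returning a [d, d] matrix (e.g. a curl-free kernel) would slot in here.
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2)) * jnp.eye(d)

def flattened_gram(xs):
    # Pairwise evaluation gives an [N, N, d, d] array of blocks.
    K = vmap(lambda x: vmap(lambda y: matrix_kernel(x, y))(xs))(xs)
    n = xs.shape[0]
    # Interleave point and output axes, then collapse: [N, d, N, d] -> [N*d, N*d].
    return jnp.transpose(K, (0, 2, 1, 3)).reshape(n * d, n * d)

xs = jnp.array([[0.0, 0.0], [0.5, 0.1], [1.0, 0.2]])  # 3 inputs in R^2
print(flattened_gram(xs).shape)  # (6, 6)
```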

Best, Emile

daniel-dodd commented 2 years ago

Hey Emile,

Thank you for your question and interest in GPJax. GPJax does not currently support this, but it is on our radar! We (me and @thomaspinder) are currently working on a multi-output GP package as a seamless extension of GPJax to support functionality of this kind. We have decided to make this a separate package (currently called "MOGPJax") to keep GPJax a light, readable codebase, while offering users greater flexibility in defining a broad range of vector-valued models and scalable inference procedures that, in general, differ from single-output models. We are in the early stages of development, but expect to make our first release public soon (once we are happy with the core structure).

Cheers, Dan

emilemathieu commented 2 years ago

Thanks @Daniel-Dodd for your answer! Would you be able to say a bit more about the core idea behind making it a seamless extension? Is it by rearranging/reshaping the mean [b, n, d] -> [b, n*d] and the covariance [b, n, n, d, d] -> [b, n*d, n*d]?
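For illustration, a minimal sketch of the reshaping described above (an assumption on my part, not taken from MOGPJax). The key detail is transposing the covariance to [b, n, d, n, d] before collapsing, so its blocks line up with the flattened mean:

```python
import jax.numpy as jnp

def flatten_moments(mean, cov):
    """mean: [b, n, d], cov: [b, n, n, d, d] -> ([b, n*d], [b, n*d, n*d])."""
    b, n, d = mean.shape
    flat_mean = mean.reshape(b, n * d)
    # Reorder to [b, n, d, n, d] so blocks match the (point, output) ordering
    # of the flattened mean, then collapse the paired axes.
    flat_cov = jnp.transpose(cov, (0, 1, 3, 2, 4)).reshape(b, n * d, n * d)
    return flat_mean, flat_cov
```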

emilemathieu commented 2 years ago

@Daniel-Dodd would you have any update regarding this MOGPJax package by any chance?

daniel-dodd commented 2 years ago

Hi @emilemathieu,

Apologies for my delay.

We are actively developing this package and working towards the first release. It will, however, take us more time. Some of this depends on us completing GPJax's v0.5 release. The first public release of MOGPJax will have (at the bare minimum): GPLVMs, isotopic conjugate GPs, and isotopic non-conjugate GPs, for which the user can do MAP estimation or MCMC inference as in GPJax. Presently, we have some rough implementations, but these need further work.

Thanks, Dan

thomascerbelaud commented 1 year ago

Hi @Daniel-Dodd and @thomaspinder,

First, thanks for this awesome library! In my research I am also very interested in using GPs (more specifically SVGPs) with multiple outputs, and was wondering whether, while waiting for MOGPJax to be officially released, there is perhaps a repo I could fork to start using multi-output GPs? I could not find MOGPJax anywhere, so I guess it is not public yet.

Also, do you have a release date yet?

Thanks a lot, Thomas

daniel-dodd commented 1 year ago

Hi @thomascerbelaud, Thank you for your kind words and interest.

Progress on this has fallen behind due to major refactoring work on GPJax versions v0.5 - v0.5.2, and the repository has not been updated in a while!

To get the ball rolling on this, and now that we have the JaxGaussianProcesses organisation established, I will aim to make this repository public soon (hopefully over this weekend, but certainly by the end of next week) once the tests pass (some things broke during the GPJax refactoring)! That way we can develop in public, so that anyone can fork and contribute.

Currently, MOGPJax has a GPLVM fitted via a MAP estimate, and we have a rough notebook implementation of a multi-output prior and conjugate posterior for isotopic datasets.

We plan to move the kernels out into a separate library, JaxKern, and are interested in thinking about a multi-output kernel abstraction. In addition, we plan in the near future to implement Kronecker linear operators over in JaxLinOp, so that we can, for example, think about linear coregionalisation model abstractions.
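As a rough sketch of the Kronecker structure mentioned above (not JaxLinOp code): for an intrinsic coregionalisation model, the joint covariance factorises as kron(B, K), with B a d x d output-mixing matrix and K the input Gram matrix:

```python
import jax.numpy as jnp

def icm_covariance(K, W, v):
    """K: [n, n] input Gram matrix; W: [d, r] mixing weights; v: [d] diagonal.

    Returns the dense [d*n, d*n] covariance kron(B, K) with B = W W^T + diag(v).
    Note the operand order: kron(B, K) groups rows/columns by output first;
    kron(K, B) groups by input first. A structured Kronecker operator would
    avoid ever densifying this product.
    """
    B = W @ W.T + jnp.diag(v)
    return jnp.kron(B, K)
```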

We would certainly be super interested in supporting a multi-output SVGP framework!

daniel-dodd commented 1 year ago

Hi @thomascerbelaud, this has now been made public: JaxGaussianProcesses/MOGPJax.

Note that this currently only has a GPLVM model. The JaxLinOp integration into GPJax has broken the multi-output prior and posterior functionality, particularly the gram/cross-covariance matrix construction, and I have not had time to fix this. I will open issues for all of these shortly.

It would be great to work towards a clean multi-output kernel abstraction, with a neat way to compute cross-covariances, gram inverses, etc., that leverages the efficient abstractions in JaxLinOp. This would probably be a good starting point. Once we have this sorted, we will have a good basis for adding MOGP models to the library.
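As a purely hypothetical sketch of what such an abstraction might expose (none of these names exist in GPJax, JaxKern, or JaxLinOp), the idea is a single interface surfacing the gram and cross-covariance block structure so downstream code can exploit it instead of always densifying:

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp

class MultiOutputKernel(ABC):
    """Hypothetical interface; the names here are illustrative only."""

    num_outputs: int

    @abstractmethod
    def gram(self, params: dict, x: jnp.ndarray):
        """Covariance over one input set, ideally returned as a structured
        (e.g. Kronecker/block) operator rather than a dense [n*d, n*d] array."""

    @abstractmethod
    def cross_covariance(self, params: dict, x: jnp.ndarray, y: jnp.ndarray):
        """Cross-covariance between two input sets; [n*d, m*d] when dense."""
```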

github-actions[bot] commented 2 months ago

There has been no recent activity on this issue. To keep our issues log clean, we remove old and inactive issues. Please update to the latest version of GPJax and check if that resolves the issue. Let us know if that works for you by leaving a comment. This issue is now marked as stale and will be closed if no further activity occurs. If you believe that this is incorrect, please comment. Thank you!

github-actions[bot] commented 2 months ago

There has been no activity on this PR for some time. Therefore, we will be automatically closing the PR if no new activity occurs within the next seven days. Thank you for your contributions.