geometric-kernels / GeometricKernels

Geometric kernels on manifolds, meshes and graphs
https://geometric-kernels.github.io/
Apache License 2.0

GPJax Frontend #77

Closed thomaspinder closed 1 year ago

thomaspinder commented 1 year ago

This PR introduces a frontend for GPJax. There is an accompanying notebook for this.

As with the GPflow frontend, this one is established by creating a wrapper around the base kernel class in GPJax, so any model within GPJax is supported.
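For readers unfamiliar with the wrapper pattern described here, a minimal sketch follows. It does not use the real GPJax or GeometricKernels APIs: `FrameworkKernel`, `ToyGeometricKernel`, `WrappedKernel` and the `K(params, x, y)` signature are hypothetical stand-ins chosen only to illustrate how a thin adapter lets generic model code call a kernel that has a different calling convention.

```python
import math
from dataclasses import dataclass


class FrameworkKernel:
    """Hypothetical stand-in for a GP library's base kernel class."""

    def __call__(self, x, y):
        raise NotImplementedError

    def gram(self, xs):
        # Generic model code relies only on __call__.
        return [[self(a, b) for b in xs] for a in xs]


class ToyGeometricKernel:
    """Hypothetical stand-in for a geometric kernel with its own call convention."""

    def K(self, params, x, y):
        # Squared-exponential in the chordal distance between two angles on a circle.
        d = 2.0 * math.sin(abs(x - y) / 2.0)
        return math.exp(-0.5 * (d / params["lengthscale"]) ** 2)


@dataclass
class WrappedKernel(FrameworkKernel):
    """Adapter: exposes the geometric kernel through the framework's interface."""

    base_kernel: ToyGeometricKernel
    lengthscale: float = 1.0

    def __call__(self, x, y):
        # Translate the framework's call into the wrapped kernel's convention.
        return self.base_kernel.K({"lengthscale": self.lengthscale}, x, y)


kernel = WrappedKernel(ToyGeometricKernel(), lengthscale=0.5)
gram = kernel.gram([0.0, 1.0, 2.0])  # 3x3 Gram matrix built by generic code
```

Because `gram` only ever calls `__call__`, anything written against the framework's base class works with the wrapped kernel unchanged, which is the property the description above relies on.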

Associated with issue #74

thomaspinder commented 1 year ago

I can see that a test has failed for Python 3.8 and 3.9 as there's no function B.from_numpy(). I ran the notebook locally on 3.9 with no issues. Do you know if there's an easy fix for this?

stoprightthere commented 1 year ago

Hey @thomaspinder, thanks for your PR! I'm going to give a full review shortly. For now, this from_numpy problem seems to stem from a mismatch in jax versions or something similar. We will figure out how to fix it and update our main accordingly.

thomaspinder commented 1 year ago

Very nice PR!

Regarding the jax failure: this happens because jax-0.4.1 introduced a new Array type that replaces DeviceArray. I have proposed a PR to lab to incorporate that change.

You may either wait until lab is updated to support jax-0.4.1, or cap the jax version in requirements (so that it is older than 0.4.1).
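In a requirements file, the cap suggested above would typically be written as the specifier `jax<0.4.1`. As a rough illustration of what that cap means, here is a small stdlib-only sketch that checks a version string against it. The helper names (`release_tuple`, `within_cap`, `installed_jax_within_cap`) are made up for this example, and the parsing deliberately ignores pre-release suffixes, so treat it as an assumption-laden sketch and not as a replacement for a real version parser.

```python
import re
from importlib import metadata

# First release reported in this thread as incompatible (the Array type change).
FIRST_UNSUPPORTED = (0, 4, 1)


def release_tuple(version):
    """Extract the leading numeric components, e.g. '0.4.1' -> (0, 4, 1)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(group) for group in match.groups())


def within_cap(version, first_unsupported=FIRST_UNSUPPORTED):
    """Return True when `version` is older than the first unsupported release."""
    return release_tuple(version) < first_unsupported


def installed_jax_within_cap():
    """Check the installed jax, or return None when jax is not installed."""
    try:
        return within_cap(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return None
```

Comparing integer tuples avoids the string-comparison pitfall where "0.10.0" would sort before "0.4.1".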

Thanks! I've capped JAX for now.