JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

feat: Kronecker kernel + computation #371

Status: Closed (daniel-dodd closed this 1 month ago)

daniel-dodd commented 1 year ago

Following the integration of CoLA (#370), it would be great to add a Kronecker kernel together with a matching kernel computation.
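For reference, here is the structure such a computation could exploit. It assumes the inputs form a Cartesian grid $X = X_1 \times X_2$ with $n_1$ and $n_2$ points per factor, ordered so that the $X_1$ index varies slowest, and that the kernel is a product of factors acting on disjoint input dimensions. Under those assumptions the Gram matrix is a Kronecker product. Matrix-vector products then cost $O(n_1 n_2 (n_1 + n_2))$ instead of $O((n_1 n_2)^2)$, and a noisy solve only needs eigendecompositions of the two small factors, $O(n_1^3 + n_2^3)$ instead of $O((n_1 n_2)^3)$. Here $\operatorname{vec}$ stacks columns.

```math
\begin{aligned}
k\big((a, b), (a', b')\big) &= k_1(a, a')\, k_2(b, b') \;\Longrightarrow\; K = K_1 \otimes K_2, \quad K_d = k_d(X_d, X_d) \\
(K_1 \otimes K_2)\,\operatorname{vec}(V) &= \operatorname{vec}\big(K_2 V K_1^{\top}\big), \quad V \in \mathbb{R}^{n_2 \times n_1} \\
K_1 \otimes K_2 + \sigma^2 I &= (Q_1 \otimes Q_2)\,(\Lambda_1 \otimes \Lambda_2 + \sigma^2 I)\,(Q_1 \otimes Q_2)^{\top}, \quad K_d = Q_d \Lambda_d Q_d^{\top}
\end{aligned}
```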

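from typing import Sequence

from gpjax.base import static_field
from gpjax.kernels import RBF, AbstractKernel
from gpjax.kernels.base import CombinationKernel
from gpjax.kernels.computations import AbstractKernelComputation

# Proposed kernel; it could inherit from CombinationKernel or ProductKernel.
class Kronecker(CombinationKernel):
    # Factor kernels, each acting on its own subset of input dimensions.
    kernels: Sequence[AbstractKernel] = None
    # KroneckerKernelComputation does not exist yet: it is the computation
    # engine this issue proposes to add.
    compute_engine: AbstractKernelComputation = static_field(KroneckerKernelComputation())
    ...

    def __post_init__(self):
        ...  # Check that each kernel's influence on its dimension (or set of
             # dimensions) is separable, i.e. the active_dims do not overlap.

# Demo: one RBF factor per input dimension.
k1 = RBF(active_dims=[0])
k2 = RBF(active_dims=[1])
kron_kern = Kronecker(kernels=[k1, k2])
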
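Below is a standalone sketch of what the proposed computation would do numerically. It is written in plain JAX rather than against GPJax internals. `rbf_gram`, `kron_mvm` and `kron_solve` are hypothetical helpers, not GPJax API. The sketch assumes two 1-D input factors on a Cartesian grid, with grid points ordered row-major (the $x_1$ index varying slowest). It never forms the $(n_1 n_2) \times (n_1 n_2)$ Gram matrix: products use the reshape identity, and the noisy solve uses per-factor eigendecompositions.

```python
import jax.numpy as jnp


def rbf_gram(x: jnp.ndarray, lengthscale: float = 1.0, variance: float = 1.0) -> jnp.ndarray:
    """Dense RBF Gram matrix for 1-D inputs x of shape (n,)."""
    sq_dist = (x[:, None] - x[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def kron_mvm(k1: jnp.ndarray, k2: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Compute (k1 ⊗ k2) @ v without materialising the Kronecker product.

    Entry i * n2 + j of v corresponds to the grid point (x1[i], x2[j]).
    """
    n1, n2 = k1.shape[0], k2.shape[0]
    V = v.reshape(n1, n2)  # row-major reshape matches jnp.kron ordering
    return (k1 @ V @ k2.T).reshape(-1)


def kron_solve(k1: jnp.ndarray, k2: jnp.ndarray, noise: float, y: jnp.ndarray) -> jnp.ndarray:
    """Solve (k1 ⊗ k2 + noise * I) alpha = y via eigendecompositions of k1 and k2."""
    n1, n2 = k1.shape[0], k2.shape[0]
    lam1, q1 = jnp.linalg.eigh(k1)
    lam2, q2 = jnp.linalg.eigh(k2)
    # Rotate y into the joint eigenbasis: (Q1 ⊗ Q2)^T y.
    Y = q1.T @ y.reshape(n1, n2) @ q2
    # Eigenvalues of k1 ⊗ k2 + noise * I are lam1[i] * lam2[j] + noise.
    Y = Y / (lam1[:, None] * lam2[None, :] + noise)
    # Rotate back: (Q1 ⊗ Q2) Y.
    return (q1 @ Y @ q2.T).reshape(-1)
```

Keeping only the factors stores $O(n_1^2 + n_2^2)$ numbers instead of $O((n_1 n_2)^2)$. The trade-off is that the inputs must lie on a full Cartesian grid and the factor kernels must act on disjoint dimensions.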
github-actions[bot] commented 2 months ago

There has been no recent activity on this issue. To keep our issues log clean, we remove old and inactive issues. Please update to the latest version of GPJax and check if that resolves the issue. Let us know if that works for you by leaving a comment. This issue is now marked as stale and will be closed if no further activity occurs. If you believe that this is incorrect, please comment. Thank you!

github-actions[bot] commented 1 month ago

There has been no activity on this PR for some time. Therefore, we will be automatically closing the PR if no new activity occurs within the next seven days. Thank you for your contributions.