LLNL / MuyGPyS

A fast, pure Python implementation of the MuyGPs Gaussian process realization and training algorithm.

Refactor backend tests so that there are numpy and (jax/torch) versions of the `MuyGPS` objects #117

Closed: bwpriest closed this issue 1 year ago

bwpriest commented 1 year ago

The jax and torch correctness tests currently create a single MuyGPS object per configuration and use it to build the objective functions for optimization. However, now that we are using HeteroscedasticNoise objects with nontrivial tensor internals, the backend whose tensors those objects hold matters, so we need to create separate kwargs for each backend, like

cls.k_kwargs_heteroscedastic_n = {
    ...
    "eps": HeteroscedasticNoise(cls.eps_heteroscedastic_n),
}
cls.k_kwargs_heteroscedastic_j = {
    ...
    "eps": HeteroscedasticNoise(cls.eps_heteroscedastic_j),
}

and then create different MuyGPS objects like

cls.muygps_heteroscedastic_n = MuyGPS(**cls.k_kwargs_heteroscedastic_n)
cls.muygps_heteroscedastic_j = MuyGPS(**cls.k_kwargs_heteroscedastic_j)
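A minimal sketch of the per-backend layout described above, as it might look in a `unittest` test class. It assumes each backend is identified by an array constructor (`numpy.asarray` for the `_n` objects, `jax.numpy.asarray` for the `_j` objects when jax is installed). `NoiseSketch`, `ModelSketch`, `make_backend_kwargs`, and the `"nu"` entry are hypothetical stand-ins, not the MuyGPyS `HeteroscedasticNoise`/`MuyGPS` API, and the sketch does not work around any limitation of the real backend.

```python
import unittest

import numpy as np


class NoiseSketch:
    """Hypothetical stand-in for HeteroscedasticNoise: stores a noise tensor only."""

    def __init__(self, val):
        self.val = val


class ModelSketch:
    """Hypothetical stand-in for MuyGPS that only keeps its constructor kwargs."""

    def __init__(self, eps, **kwargs):
        self.eps = eps
        self.kwargs = kwargs


# Backend suffix -> (array constructor, array type that constructor returns).
BACKENDS = {"n": (np.asarray, np.ndarray)}
try:
    import jax
    import jax.numpy as jnp

    BACKENDS["j"] = (jnp.asarray, jax.Array)
except (ImportError, AttributeError):
    pass  # jax missing (or too old to expose jax.Array): test numpy only


def make_backend_kwargs(base_kwargs, eps_values):
    """Build one kwargs dict per backend, each with its own noise object."""
    kwargs_by_backend = {}
    for suffix, (to_array, _) in BACKENDS.items():
        kwargs = dict(base_kwargs)  # shallow copy of the shared settings
        kwargs["eps"] = NoiseSketch(to_array(eps_values))
        kwargs_by_backend[suffix] = kwargs
    return kwargs_by_backend


class BackendModelsTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # "nu" is a placeholder for the elided shared kwargs ("...").
        cls.k_kwargs = make_backend_kwargs({"nu": 0.5}, [1e-5, 2e-5, 3e-5])
        # One model per backend, built from that backend's own kwargs.
        cls.muygps = {
            suffix: ModelSketch(**kwargs) for suffix, kwargs in cls.k_kwargs.items()
        }

    def test_noise_tensor_matches_backend(self):
        for suffix, (_, array_type) in BACKENDS.items():
            with self.subTest(backend=suffix):
                self.assertIsInstance(self.muygps[suffix].eps.val, array_type)

    def test_noise_objects_not_shared(self):
        # Every backend's model must hold a distinct noise object.
        noise_ids = {id(model.eps) for model in self.muygps.values()}
        self.assertEqual(len(noise_ids), len(self.muygps))
```

Run it with `python -m unittest`. Keeping the objects in dicts keyed by backend suffix lets one test method loop over backends with `subTest` instead of hand-writing parallel `_n` and `_j` attributes.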

bwpriest commented 1 year ago

I do not think that this is possible with how the backend is currently implemented. The current tests will have to do.