LLNL / MuyGPyS

A fast, pure Python implementation of the MuyGPs Gaussian process realization and training algorithm.

Anisotropic feature integrated into library. Tests added to gp.py and kernel.py #127

Closed: alecmdunton closed this 1 year ago

alecmdunton commented 1 year ago

One thing to discuss is how the length scales are fed into the AnisotropicDistortion model. I currently define them as a dictionary with keys "length_scale0", "length_scale1", and so on. I am happy to change this if you think a different design would be better. I have also left adding tests to tests/backend for a later PR.
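For discussion, here is a minimal sketch of how a dictionary keyed "length_scale0", "length_scale1", ... could be consumed. It assumes each key maps to a plain float and that the distortion divides each feature difference by its own length scale before taking the Euclidean norm. `ordered_length_scales` and `anisotropic_l2` are hypothetical helpers for illustration, not MuyGPyS API.

```python
import math
import re


def ordered_length_scales(length_scales):
    """Return length scale values sorted by their integer suffix.

    Sorting on the parsed integer avoids the lexicographic trap where
    "length_scale10" would sort before "length_scale2".
    """
    pattern = re.compile(r"length_scale(\d+)$")
    indexed = []
    for key, value in length_scales.items():
        match = pattern.match(key)
        if match is None:
            raise ValueError(f"unexpected length scale key: {key!r}")
        indexed.append((int(match.group(1)), value))
    indexed.sort()
    # Require contiguous indices 0..d-1 so every feature gets a scale.
    if [i for i, _ in indexed] != list(range(len(indexed))):
        raise ValueError("length scale keys must run contiguously from length_scale0")
    return [value for _, value in indexed]


def anisotropic_l2(diff, **length_scales):
    """Hypothetical anisotropic distance: scale each feature difference by
    its own length scale, then take the Euclidean norm."""
    scales = ordered_length_scales(length_scales)
    if len(scales) != len(diff):
        raise ValueError("need exactly one length scale per feature dimension")
    return math.sqrt(sum((d / s) ** 2 for d, s in zip(diff, scales)))


# (3/1)^2 + (4/2)^2 = 13, so this prints sqrt(13).
print(anisotropic_l2([3.0, 4.0], length_scale0=1.0, length_scale1=2.0))
```

One upside of numbered keys is that they can be passed straight through as `**kwargs`; the cost is that the ordering and completeness of the keys have to be validated somewhere, as above.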

alecmdunton commented 1 year ago

My first comment is that these tests fail in Python 3.8 and 3.9 because I call isinstance(distortion_fn, Union[AnisotropicDistortion, IsotropicDistortion]). Should I just rewrite this as isinstance(distortion_fn, AnisotropicDistortion) or isinstance(distortion_fn, IsotropicDistortion)?
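A version-portable alternative is to pass a tuple of classes to `isinstance`, which Python has always accepted, whereas `typing.Union[...]` raises a `TypeError` in instance checks on Python 3.8 and 3.9. In this sketch the two classes are empty local stand-ins for the MuyGPyS classes, and `is_supported_distortion` is a hypothetical helper name.

```python
class IsotropicDistortion:
    """Local stand-in for the MuyGPyS class of the same name."""


class AnisotropicDistortion:
    """Local stand-in for the MuyGPyS class of the same name."""


def is_supported_distortion(distortion_fn):
    # A tuple of classes works with isinstance on every Python 3 version,
    # unlike typing.Union[...], which fails on 3.8 and 3.9.
    return isinstance(distortion_fn, (AnisotropicDistortion, IsotropicDistortion))


print(is_supported_distortion(IsotropicDistortion()))    # True
print(is_supported_distortion(AnisotropicDistortion()))  # True
print(is_supported_distortion(object()))                 # False
```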

bwpriest commented 1 year ago

I found a bug introduced in the last PR relating to how the distortion model gets incorporated into KernelFn._fn, and consequently how the optimization function gets built. It is a holdover from when the distortion model had no parameters, so it made sense to fold it in the way we are currently doing it. Basically, we always use the length_scale value that was used to construct the IsotropicDistortion, even though opt_fn receives new length_scale kwarg values. I should be able to fix the problem pretty quickly.
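A toy illustration of this failure mode, assuming a simplified "kernel" that just divides a difference by the length scale. These functions are hypothetical and are not the actual KernelFn._fn code; the point is only where the length_scale value gets bound, at construction time versus at call time.

```python
def make_kernel_fn_buggy(length_scale):
    # The distortion is folded in once, capturing the construction-time
    # length_scale; later kwargs are accepted but silently ignored.
    def fn(diff, **kwargs):
        return diff / length_scale
    return fn


def make_kernel_fn_fixed(default_length_scale):
    # length_scale is read at call time, so an optimizer that passes a
    # new value actually changes the output.
    def fn(diff, length_scale=default_length_scale):
        return diff / length_scale
    return fn


buggy = make_kernel_fn_buggy(1.0)
fixed = make_kernel_fn_fixed(1.0)
print(buggy(4.0, length_scale=2.0))  # 4.0: the new length_scale is ignored
print(fixed(4.0, length_scale=2.0))  # 2.0: the new length_scale is used
```

With the first pattern the objective is flat in length_scale, so an optimizer has no signal to move it.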

bwpriest commented 1 year ago


PR #128 fixed the bug that I found. Optimization now appears to be sensitive to length_scale, so it might be worth revisiting those notebooks once you have merged in those changes.

alecmdunton commented 1 year ago

Isotropic is working now - debugging Anisotropic

alecmdunton commented 1 year ago

Still work to be done but getting there.

alecmdunton commented 1 year ago

@bwpriest there may be a problem with JAX and passing Hyperparameter objects as kwargs into the __call__ method of AnisotropicDistortion. We need to leave the length_scale kwargs in kwarg form until they are passed into __call__, so that the objective is sensitive to, and the model is updated with respect to, e.g., length_scale0, length_scale1, and so on. If we pass a kwarg like length_scale0=Hyperparameter(1.5) as an argument when constructing an AnisotropicDistortion object, we get the following error:

TypeError: Argument '<MuyGPyS.gp.kernels.hyperparameters.Hyperparameter object at 0x28662db40>' of type <class 'MuyGPyS.gp.kernels.hyperparameters.Hyperparameter'> is not a valid JAX type.

Unfortunately I think we have to leave this object as a Hyperparameter object until it is passed into __call__. One solution I have in mind is that I add a function to the MuyGPyS/gp/distortion/embed.py module which goes through the kwargs **length_scales one by one and evaluates Hyperparameter.__call__ on each length_scale* before it is passed into apply_anisotropic_distortion. What are your thoughts?
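The proposed helper might look roughly like the sketch below. It rests on stated assumptions: `Hyperparameter` here is a minimal local stand-in whose `__call__` returns its current value (matching the behavior described above), and `unwrap_length_scales` and `toy_anisotropic_distortion` are hypothetical names, not functions in MuyGPyS/gp/distortion/embed.py.

```python
class Hyperparameter:
    """Local stand-in: calling the object returns its current value."""

    def __init__(self, value):
        self._value = value

    def __call__(self):
        return self._value


def unwrap_length_scales(**length_scales):
    # Replace each Hyperparameter with the plain number it holds, so only
    # numeric values reach array code (JAX rejects arbitrary Python objects).
    return {
        name: (value() if isinstance(value, Hyperparameter) else value)
        for name, value in length_scales.items()
    }


def toy_anisotropic_distortion(diff, **length_scales):
    # Hypothetical numeric routine; assumes keys length_scale0, length_scale1, ...
    return [d / length_scales[f"length_scale{i}"] for i, d in enumerate(diff)]


raw = unwrap_length_scales(length_scale0=Hyperparameter(1.5), length_scale1=0.5)
print(raw)  # {'length_scale0': 1.5, 'length_scale1': 0.5}
print(toy_anisotropic_distortion([3.0, 1.0], **raw))  # [2.0, 2.0]
```

Keeping the unwrapping at the boundary means the numeric routine only ever sees plain floats passed as kwargs, which keeps the objective a function of length_scale0, length_scale1, and so on.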

alecmdunton commented 1 year ago

I answered my own question.

alecmdunton commented 1 year ago

For some reason, the one test that keeps failing is this pesky crosswise_matern test in lines 674-687 of jax_correctness.

        self.assertTrue(
            allclose_gen(
                matern_gen_anisotropic_fn_n(
                    self.crosswise_diffs_n,
                    nu=self.nu,
                    length_scale0=self.length_scale,
                ),
                matern_gen_anisotropic_fn_j(
                    self.crosswise_diffs_j,
                    nu=self.nu,
                    length_scale0=self.length_scale,
                ),
            )
        )

The analogous pairwise_matern_test passes...