Closed: alecmdunton closed this 1 year ago.
First comment: these tests fail in Python 3.8 and 3.9 because I call isinstance(distortion_fn, Union[AnisotropicDistortion, IsotropicDistortion]). Should I just rewrite this as isinstance(distortion_fn, AnisotropicDistortion) or isinstance(distortion_fn, IsotropicDistortion)?
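For reference, isinstance with a subscripted typing.Union raises TypeError on Python 3.8 and 3.9 (newer versions accept union types), while passing a tuple of classes works on every Python 3 version. A minimal sketch, with placeholder classes standing in for the real MuyGPyS distortion classes (the helper name is hypothetical):

```python
# Placeholder stand-ins for the real MuyGPyS distortion classes, only so the
# example is self-contained.
class IsotropicDistortion:
    pass


class AnisotropicDistortion:
    pass


def is_supported_distortion(distortion_fn) -> bool:
    # A tuple of classes is accepted by isinstance on every Python 3 version,
    # whereas isinstance(x, typing.Union[...]) raises TypeError before 3.10.
    return isinstance(distortion_fn, (AnisotropicDistortion, IsotropicDistortion))


print(is_supported_distortion(IsotropicDistortion()))  # True
print(is_supported_distortion(object()))  # False
```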
I found a bug, introduced in the last PR, in how the distortion model gets incorporated into KernelFn._fn and, consequently, in how the optimization function gets built. It is a holdover from when the distortion model had no parameters, when it made sense to fold it in the way we currently do. Basically, we always use the length_scale value that was used to create the IsotropicDistortion, even though opt_fn receives new length_scale kwarg values. I should be able to fix the problem pretty quickly.
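To illustrate the failure mode described above, here is a minimal sketch using hypothetical function names and a toy exponential kernel (not the actual KernelFn._fn code): folding the length scale in at construction time freezes it, while reading it from the call-time kwargs lets the optimizer's proposals take effect.

```python
import math


def make_kernel_fn_stale(length_scale: float):
    # Buggy pattern: the length scale is captured once, when the function is
    # built, so any length_scale the optimizer passes later is ignored.
    def kernel_fn(dist: float, **kwargs) -> float:
        return math.exp(-dist / length_scale)

    return kernel_fn


def make_kernel_fn_fixed(default_length_scale: float):
    # Fixed pattern: look up length_scale in the call-time kwargs and fall
    # back to the construction-time value only when it is not supplied.
    def kernel_fn(dist: float, **kwargs) -> float:
        length_scale = kwargs.get("length_scale", default_length_scale)
        return math.exp(-dist / length_scale)

    return kernel_fn


stale = make_kernel_fn_stale(1.0)
fixed = make_kernel_fn_fixed(1.0)
# The optimizer proposes length_scale=2.0; only the fixed version responds.
print(stale(1.0, length_scale=2.0))  # ≈ 0.3679, i.e. exp(-1.0)
print(fixed(1.0, length_scale=2.0))  # ≈ 0.6065, i.e. exp(-0.5)
```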
PR #128 fixed the bug that I found. Optimization now appears to be sensitive to length_scale, so it might be worth revisiting those notebooks once you merge up.
Isotropic is working now; debugging Anisotropic.
Still work to be done but getting there.
@bwpriest there may be a problem with JAX and passing Hyperparameter objects as kwargs into the __call__ method of AnisotropicDistortion. We need to leave the length_scales kwargs in kwarg form until they are passed into __call__, so that the objective is sensitive to, and the model is updated with respect to, e.g., length_scale0, length_scale1, and so on. If we pass a kwarg like length_scale0=Hyperparameter(1.5) as an argument when constructing an AnisotropicDistortion object, we get the following error:
TypeError: Argument '<MuyGPyS.gp.kernels.hyperparameters.Hyperparameter object at 0x28662db40>' of type <class 'MuyGPyS.gp.kernels.hyperparameters.Hyperparameter'> is not a valid JAX type.
Unfortunately, I think we have to keep this object as a Hyperparameter until it is passed into __call__. One solution I have in mind is to add a function to the MuyGPyS/gp/distortion/embed.py module that goes through the **length_scales kwargs one by one and evaluates Hyperparameter.__call__ on each length_scale* before they are passed into apply_anisotropic_distortion. What are your thoughts?
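A minimal sketch of the proposed helper, assuming (as the comment above implies) that calling a Hyperparameter returns its current value. The Hyperparameter class here is a hypothetical stand-in, and evaluate_length_scales and call_distortion are illustrative names, not existing MuyGPyS functions:

```python
from typing import Any, Callable, Dict


class Hyperparameter:
    # Hypothetical stand-in for MuyGPyS's Hyperparameter: calling the object
    # returns its current value as a plain float, which JAX can handle.
    def __init__(self, value: float):
        self._value = value

    def __call__(self) -> float:
        return self._value


def evaluate_length_scales(**length_scales: Any) -> Dict[str, float]:
    # Replace every Hyperparameter with its current value; leave values that
    # are already numeric (e.g. ones injected by the optimizer) untouched.
    return {
        name: value() if isinstance(value, Hyperparameter) else value
        for name, value in length_scales.items()
    }


def call_distortion(fn: Callable[..., Any], diffs: Any, **length_scales: Any) -> Any:
    # Unwrap just before the numerical function (e.g. a jitted one) is called,
    # so the length scales stay in kwarg form up to this point.
    return fn(diffs, **evaluate_length_scales(**length_scales))


print(evaluate_length_scales(length_scale0=Hyperparameter(1.5), length_scale1=2.0))
# {'length_scale0': 1.5, 'length_scale1': 2.0}
```

Keeping the unwrapping at the call boundary means only plain numbers ever reach the JAX code, while the optimizer can still override any length_scale* kwarg by name.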
I answered my own question.
For some reason, the one test that keeps failing is this pesky crosswise_matern test in lines 674-687 of jax_correctness.
self.assertTrue(
    # Crosswise Matérn kernel: the NumPy (_n) and JAX (_j) results should agree.
    allclose_gen(
        matern_gen_anisotropic_fn_n(
            self.crosswise_diffs_n,
            nu=self.nu,
            length_scale0=self.length_scale,
        ),
        matern_gen_anisotropic_fn_j(
            self.crosswise_diffs_j,
            nu=self.nu,
            length_scale0=self.length_scale,
        ),
    )
)
The analogous pairwise_matern_test passes...
One thing to discuss is how the length scales are fed into the AnisotropicDistortion model. I am currently defining them as a dictionary with keys "length_scale0", "length_scale1", and so on. I am happy to change this if you think a reworking would be better. I have also left adding tests to tests/backend for a later PR.
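One detail worth keeping in mind with a dictionary keyed "length_scale0", "length_scale1", and so on: the values need to be ordered by the integer suffix, since plain string sorting puts "length_scale10" before "length_scale2" once there are more than ten features. A minimal sketch, assuming the anisotropic distortion divides each feature difference by its own length scale before taking the Euclidean norm; the function names are hypothetical:

```python
import math
import re
from typing import Dict, List, Sequence

_KEY_PATTERN = re.compile(r"length_scale(\d+)")


def ordered_length_scales(length_scales: Dict[str, float]) -> List[float]:
    # Order values by the integer suffix of each key, not by string order.
    indexed = []
    for key, value in length_scales.items():
        match = _KEY_PATTERN.fullmatch(key)
        if match is None:
            raise ValueError(f"unexpected length-scale key: {key!r}")
        indexed.append((int(match.group(1)), value))
    indexed.sort()
    # Require a contiguous set of indices 0, 1, ..., n - 1.
    if [i for i, _ in indexed] != list(range(len(indexed))):
        raise ValueError("keys must be length_scale0 through length_scale{n-1}")
    return [value for _, value in indexed]


def anisotropic_distance(diff: Sequence[float], length_scales: Dict[str, float]) -> float:
    # Rescale each feature difference by its own length scale, then take the
    # Euclidean norm of the rescaled vector.
    scales = ordered_length_scales(length_scales)
    if len(scales) != len(diff):
        raise ValueError("need exactly one length scale per feature")
    return math.sqrt(sum((d / s) ** 2 for d, s in zip(diff, scales)))


print(anisotropic_distance([3.0, 8.0], {"length_scale1": 2.0, "length_scale0": 1.0}))
# 5.0
```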