Closed jecampagne closed 2 years ago
Yes - I want to add an option for custom kernels in the near future. We can also add more kernels to the kernels.py.
Do you have any specific kernel in mind?
Noise should be a float - thanks for catching it!
Hi @ziatdinovmax, I was wondering whether it might be more productive to let the user supply the kernel. For instance, here is a simple squared exponential kernel.
The noise is then also in the hands of the user.
# squared exponential kernel with diagonal noise term
def kernel(X, Z, var, length, noise=0.1, jitter=1.0e-6, include_noise=True):
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
and then pass the kernel function to the ExactGP class.
Incidentally, I have not managed to get noise_prior working. I tried to follow the kernel_prior example, but it crashes.
Have a nice day.
Hi, here is a class that is largely inspired by yours.
Note that adding the dense_mass option to NUTS is very valuable.
from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import vmap
from numpyro.infer import MCMC, NUTS, init_to_median

# Note: clear_cache() used below is not defined in this snippet
# (assumed to be the helper used in gpax).

class GaussProc:
    """
    Gaussian process class
    Args:
        kernel: GP kernel
        mean_fn: optional deterministic mean function (use 'mean_fn_prior' to make it probabilistic)
        kernel_prior: optional custom priors over kernel hyperparameters
        mean_fn_prior: optional priors over mean function parameters
        noise_prior: optional custom prior for observation noise
    """
    def __init__(self, kernel: Callable[[jnp.ndarray,
                                         jnp.ndarray,
                                         Dict[str, jnp.ndarray],
                                         float, float], jnp.ndarray],
                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None
                 ) -> None:
        clear_cache()
        self.kernel = kernel
        self.mean_fn = mean_fn
        self.kernel_prior = kernel_prior
        self.mean_fn_prior = mean_fn_prior
        self.noise_prior = noise_prior
        self.mcmc = None

    def model(self, X: jnp.ndarray, y: jnp.ndarray):
        """GP probabilistic model"""
        # Initialize mean function at zeros
        f_loc = jnp.zeros(X.shape[0])
        # Sample kernel parameters
        kernel_params = self.kernel_prior()
        noise = self.noise_prior()
        # Add mean function (if any)
        if self.mean_fn is not None:
            args = [X]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # compute GP K(X,X)
        K = self.kernel(
            X, X,
            kernel_params,
            noise
        )
        # sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=K),
            obs=y,
        )

    def fit(self, rng_key: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
            num_warmup: int = 1_000,
            num_samples: int = 1_000,
            num_chains: int = 1,
            chain_method: str = 'vectorized',
            dense_mass: bool = True,
            progress_bar: bool = True,
            print_summary: bool = True
            ) -> None:
        """
        Fit GP Kernel parameters using MCMC NUTS
        Args:
            rng_key: random number generator key
            X: 2D 'feature vector' with :math:`n x num_features`
            y: 1D 'target vector' with :math:`(n,)` dimensions
            num_warmup: number of MCMC warmup states
            num_samples: number of MCMC samples
            num_chains: number of MCMC chains
            chain_method: 'sequential', 'parallel' or 'vectorized'
            dense_mass: diagonal HMC mass matrix or full dense (optimized during warmup)
            progress_bar: show progress bar
            print_summary: print summary at the end of sampling
        """
        init_strategy = init_to_median(num_samples=100)
        kernel_nuts = NUTS(self.model, init_strategy=init_strategy, dense_mass=dense_mass)
        self.mcmc = MCMC(
            kernel_nuts,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method=chain_method,
            progress_bar=progress_bar
        )
        self.mcmc.run(rng_key, X, y)
        if print_summary:
            self.mcmc.print_summary()

    def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
        """Get posterior samples (after running the MCMC chains)"""
        return self.mcmc.get_samples(group_by_chain=chain_dim)

    def get_mvn_posterior(self,
                          rng_key: jnp.ndarray,
                          X_train: jnp.ndarray, y_train: jnp.ndarray,
                          X_new: jnp.ndarray,
                          params: Dict[str, jnp.ndarray],
                          noise: float = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns the mean and one noisy draw (mean + std * N(0, 1)) of the
        multivariate normal posterior for a single sample of GP hyperparameters
        """
        y_residual = y_train
        if self.mean_fn is not None:
            args = [X_train, params] if self.mean_fn_prior else [X_train]
            y_residual -= self.mean_fn(*args).squeeze()
        # compute kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, noise)
        k_pX = self.kernel(X_new, X_train, params, jitter=0.0)
        k_XX = self.kernel(X_train, X_train, params, noise)
        # compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(cov), a_min=0.0)) * jax.random.normal(
            rng_key, X_new.shape[:1])
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()
        return mean, mean + sigma_noise

    def predict(self, rng_key: jnp.ndarray,
                X_train: jnp.ndarray, y_train: jnp.ndarray,
                X_new: jnp.ndarray,
                samples: Optional[Dict[str, jnp.ndarray]] = None,
                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        X_new = X_new if X_new.ndim > 1 else X_new[:, None]
        if samples is None:
            samples = self.get_samples(chain_dim=False)
        num_samples = samples[list(samples.keys())[0]].shape[0]
        # do prediction
        vmap_args = (
            jax.random.split(rng_key, num_samples),
            samples,
            samples["noise"]
        )
        means, predictions = vmap(
            lambda rng_key, samples, noise: self.get_mvn_posterior(
                rng_key, X_train, y_train, X_new, samples, noise)
        )(*vmap_args)
        return means, predictions
I use it like this. First, the definition of all functions and priors:
def mean_fn(x, params):
    """Power-law behavior before and after the transition"""
    return jnp.piecewise(
        x, [x < 0, x >= 0],
        [lambda x: params["R0"]+params["v"]*x,
         lambda x: params["R0"]+params["v"]*x - params["k"]*(1.-jnp.exp(-x/params["tau"]))
        ])

def mean_fn_prior():
    # Sample model parameters
    R0 = numpyro.sample("R0", numpyro.distributions.Uniform(0.1, 50))
    v = numpyro.sample("v", numpyro.distributions.Uniform(0.1, 50))
    k = numpyro.sample("k", numpyro.distributions.Uniform(0.1, 50))
    tau = numpyro.sample("tau", numpyro.distributions.Uniform(0.1, 50))
    # Return sampled parameters as a dictionary
    return {"R0": R0, "v": v, "k": k, "tau": tau}

def kernel_prior():
    length = numpyro.sample("k_length", numpyro.distributions.Normal(0.83, 0.05))
    scale = numpyro.sample("k_scale", numpyro.distributions.LogNormal(0, 1))
    return {"k_length": length, "k_scale": scale}

def noise_prior():
    noise = numpyro.sample("noise", dist.Normal(1.0, 0.1))
    return noise
Then instantiation of the class
gp = GaussProc(kernel=kernel_RBF,
               kernel_prior=kernel_prior,
               noise_prior=noise_prior,
               mean_fn=mean_fn,
               mean_fn_prior=mean_fn_prior)
Fit the model
rng_key, rng_key_predict = jax.random.split(jax.random.PRNGKey(0))
gp.fit(rng_key,X,y, num_warmup=5_000, num_samples=1_000)
Get the samples and check the parameter fit with arviz. Then get the means and predictions. I do not resample for each parameter set generated by the NUTS MCMC chain.
means, predictions = gp.predict(rng_key_predict, X_train=X, y_train=y, X_new=X_new)
Then I can plot something like:
Concerning the GP kernels, here are some, but I do not address how to add or multiply them:
def kernel_RBF(X: jnp.ndarray,
               Z: jnp.ndarray,
               params: Dict[str, jnp.ndarray],
               noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    r2 = square_scaled_distance(X, Z, params["k_length"])
    k = params["k_scale"] * jnp.exp(-0.5 * r2)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern12(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=1/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    k = params["k_scale"] * jnp.exp(-r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern32(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=3/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    sqrt3_r = 3**0.5 * r
    k = params["k_scale"] * (1.0 + sqrt3_r) * jnp.exp(-sqrt3_r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern52(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=5/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    sqrt5_r = 5**0.5 * r
    k = params["k_scale"] * (1.0 + sqrt5_r + sqrt5_r**2 / 3.0) * jnp.exp(-sqrt5_r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
Voila my 2-cents contrib. I do not know if it fits your plans.
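The kernels above leave open how to add or multiply them. One possible approach, sketched here in dependency-free Python on 1D inputs rather than in JAX, is to keep the component kernels noise-free and build the sum or product with small wrapper functions. The names add_kernels, mul_kernels and gram_with_noise are hypothetical, and sharing one params dictionary between both components is a simplifying assumption. The noise and jitter are added once after composition; if each component added its own noise term, as the kernels above do, a sum would count it twice.

```python
import math
from typing import Callable, Dict, List, Sequence

Matrix = List[List[float]]
KernelFn = Callable[[Sequence[float], Sequence[float], Dict[str, float]], Matrix]


def rbf(X, Z, params):
    """Noise-free squared exponential kernel for 1D inputs."""
    return [[params["k_scale"] * math.exp(-0.5 * ((x - z) / params["k_length"]) ** 2)
             for z in Z] for x in X]


def matern12(X, Z, params):
    """Noise-free Matern nu=1/2 (exponential) kernel for 1D inputs."""
    return [[params["k_scale"] * math.exp(-abs(x - z) / params["k_length"])
             for z in Z] for x in X]


def _combine(k1: KernelFn, k2: KernelFn, op) -> KernelFn:
    """Build a kernel whose matrix is the elementwise `op` of k1 and k2."""
    def combined(X, Z, params):
        A, B = k1(X, Z, params), k2(X, Z, params)
        return [[op(a, b) for a, b in zip(row_a, row_b)]
                for row_a, row_b in zip(A, B)]
    return combined


def add_kernels(k1: KernelFn, k2: KernelFn) -> KernelFn:
    """Sum of two kernels (hypothetical helper)."""
    return _combine(k1, k2, lambda a, b: a + b)


def mul_kernels(k1: KernelFn, k2: KernelFn) -> KernelFn:
    """Elementwise product of two kernels (hypothetical helper)."""
    return _combine(k1, k2, lambda a, b: a * b)


def gram_with_noise(kernel: KernelFn, X, params, noise=0.0, jitter=1.0e-6) -> Matrix:
    """K(X, X) with (noise + jitter) added to the diagonal exactly once."""
    K = kernel(X, X, params)
    return [[k + (noise + jitter if i == j else 0.0) for j, k in enumerate(row)]
            for i, row in enumerate(K)]


params = {"k_scale": 2.0, "k_length": 1.0}
X = [0.0, 1.0]
k_sum = add_kernels(rbf, matern12)
k_prod = mul_kernels(rbf, matern12)
K = gram_with_noise(k_sum, X, params, noise=0.1, jitter=0.0)
```

Sums and elementwise products of valid covariance functions are again valid covariance functions, so the composed callables can be passed wherever a single kernel is expected.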
Looks awesome! I found that it was intuitively easier for end-users (usually domain scientists) to select a kernel by passing a string with the kernel name - hence the current setup. But I agree that it should be possible for advanced users to pass a callable. I'm currently working on extending the gp and dkl modules to vector-valued targets and once completed (hopefully this week), I will get back to this issue.
Note that the title of the picture is not correct, as it is not a vanilla Gaussian process, but I think you have already corrected that yourself.
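Accepting both a kernel name and a callable can be done with a small dispatch step. The sketch below is only an illustration of that idea, not the gpax implementation: the KERNELS registry and the resolve_kernel helper are hypothetical names, and the kernel works on plain 1D lists to keep the example dependency-free.

```python
import math
from typing import Callable, Dict, List, Sequence, Union

Matrix = List[List[float]]
KernelFn = Callable[[Sequence[float], Sequence[float], Dict[str, float]], Matrix]


def rbf(X, Z, params):
    """Noise-free squared exponential kernel for 1D inputs."""
    return [[params["k_scale"] * math.exp(-0.5 * ((x - z) / params["k_length"]) ** 2)
             for z in Z] for x in X]


# Hypothetical registry of built-in kernels, keyed by the names users pass as strings.
KERNELS: Dict[str, KernelFn] = {"RBF": rbf}


def resolve_kernel(kernel: Union[str, KernelFn]) -> KernelFn:
    """Return a kernel callable from either a registered name or a user function."""
    if callable(kernel):
        return kernel
    if kernel in KERNELS:
        return KERNELS[kernel]
    raise ValueError(
        f"Unknown kernel {kernel!r}; expected one of {sorted(KERNELS)} or a callable"
    )


def my_kernel(X, Z, params):
    """User-defined kernel: a constant covariance, purely for demonstration."""
    return [[params["k_scale"] for _ in Z] for _ in X]


builtin = resolve_kernel("RBF")      # looked up by name
custom = resolve_kernel(my_kernel)   # passed through unchanged
```

With this pattern the string interface stays the default for end-users, while a callable bypasses the registry entirely.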
I found that
k_pp = self.kernel(X_new, X_new, params, noise)
is not correct and one should use
k_pp = self.kernel(X_new, X_new, params, jitter=0)
to avoid the "noise" in the "test" extrapolation.
Yes, but then you'll have to add noise to your cov.diagonal() unless your function is noiseless ;-) , as in e.g. here. In the end, both approaches produce similar results, but our approach is 'cleaner':
The current implementation:
The implementation where we follow a standard procedure as in the link above and do not include noise into k_pp but add noise to cov.diagonal()
such that y_sampled = dist.Normal(y_mean, jnp.sqrt(K)).sample(rng_key, sample_shape=(n,))
where K = cov.diagonal() + noise
I should have clarified that the reason the results are 'similar' and not the 'same' is because we used a full covariance matrix for sampling and compared it to a solution that used only variance. So, if we follow standard practice and do not include noise into k_pp and then add it to the diagonal of a covariance matrix as
diag_elements = jnp.diag_indices_from(cov)
cov_new = jax.ops.index_update(cov, diag_elements, cov.diagonal() + noise)
we will get identical results with the current gpax implementation.
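The equivalence claimed above can be checked numerically. The following is a minimal dependency-free sketch with two training points and two test points, so that a hand-written 2x2 inverse suffices; all helper names are hypothetical. It compares the predictive covariance with noise included in k_pp against the noise-free covariance with noise added to its diagonal afterwards. Note that jax.ops.index_update has since been removed from JAX; in current releases the diagonal update would be written as cov.at[diag_elements].add(noise).

```python
import math


def rbf(X, Z, scale, length):
    """Noise-free squared exponential kernel for 1D inputs."""
    return [[scale * math.exp(-0.5 * ((x - z) / length) ** 2) for z in Z] for x in X]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def inv2(M):
    """Inverse of a 2x2 matrix (enough for two training points)."""
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def add_to_diagonal(M, value):
    return [[m + (value if i == j else 0.0) for j, m in enumerate(row)]
            for i, row in enumerate(M)]


def predictive_cov(X_train, X_new, scale, length, noise, noise_in_k_pp):
    """cov = k_pp - k_pX K_XX^{-1} k_pX^T, with or without noise inside k_pp."""
    k_XX = add_to_diagonal(rbf(X_train, X_train, scale, length), noise)
    k_pX = rbf(X_new, X_train, scale, length)
    k_pp = rbf(X_new, X_new, scale, length)
    if noise_in_k_pp:
        k_pp = add_to_diagonal(k_pp, noise)
    reduction = matmul(k_pX, matmul(inv2(k_XX), transpose(k_pX)))
    return [[p - r for p, r in zip(row_p, row_r)] for row_p, row_r in zip(k_pp, reduction)]


X_train, X_new = [0.0, 1.0], [0.5, 2.0]
scale, length, noise = 1.0, 0.8, 0.1

# Noise included in k_pp (the "current implementation" described above)
cov_current = predictive_cov(X_train, X_new, scale, length, noise, noise_in_k_pp=True)
# Noise-free k_pp, with noise added to the diagonal afterwards
cov_standard = predictive_cov(X_train, X_new, scale, length, noise, noise_in_k_pp=False)
cov_standard_plus_noise = add_to_diagonal(cov_standard, noise)
```

The two matrices agree entry by entry up to floating-point rounding, because the noise term only ever touches the diagonal of k_pp.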
Well, I was referring to
C. E. Rasmussen and C. K. I. Williams, Gaussian Processes for Machine Learning, MIT Press, 2006, page 16.
There, the K(X,X) term is free of measurement noise (sigma), as this is prediction, but it is not free of "noise" in the sense that the scale of the kernel has the same semantics.
Notice that in the JAX gaussian_process_regression.py there is no noise added to the amp * cov_map(exp_quadratic, xtest) term (here).
The logic here is that the new/unseen data is assumed to follow the same distribution as the training data. Hence, if we introduced a model noise for the training data, we also want to include that noise in prediction. I think we can introduce a keyword argument to the prediction method(s) that allows a user to select between noise-free and noisy predictions.
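A keyword of the kind proposed above could look roughly as follows. This is a sketch under strong simplifying assumptions (a single training pair, a single test point, an RBF kernel), and the function name predict_point and the keyword noiseless are hypothetical, not part of gpax. It shows that the switch only affects the predictive variance, never the mean.

```python
import math


def predict_point(x_train, y_train, x_new, scale, length, noise, noiseless=False):
    """Posterior mean and variance at x_new given a single training pair.

    With noiseless=True the variance describes the latent function f;
    otherwise the observation noise is added, describing a new measurement y.
    """
    k_XX = scale + noise                                                # K(X, X) + noise
    k_pX = scale * math.exp(-0.5 * ((x_new - x_train) / length) ** 2)   # K(X_new, X)
    k_pp = scale                                                        # K(X_new, X_new), noise-free
    mean = k_pX * y_train / k_XX
    var_latent = k_pp - k_pX * k_pX / k_XX
    var = var_latent if noiseless else var_latent + noise
    return mean, var


mean_f, var_f = predict_point(0.0, 1.2, 0.5, scale=1.0, length=0.8, noise=0.1, noiseless=True)
mean_y, var_y = predict_point(0.0, 1.2, 0.5, scale=1.0, length=0.8, noise=0.1)
```

Here var_y exceeds var_f by exactly the noise variance, while mean_f and mean_y coincide.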
Well, when you practice Kalman filtering, for instance, up to a certain step you have absorbed (k-1) points and you have everything in hand to produce X_{k/k-1}, which is the extrapolation of the state vector X (here according to a linear propagation given by a matrix A), and C_{k/k-1}, its covariance matrix.
Now, if you want to add a new measurement at step "k", then of course you update your state vector and covariance matrix with the information from this new measurement. But before considering this new measurement, you do not include any "noise" coming from a measurement in the extrapolation mechanism.
The same holds, by analogy, for the prediction of f_\ast and its covariance matrix, which are the equivalents of X_{k/k-1} and C_{k/k-1} in the Kalman filtering case.
I would encourage you to add a keyword and a few lines to explain that (in better English). Best.
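The Kalman analogy can be made concrete with a scalar filter. This is a textbook-style sketch, not code from the thread: the process noise q and the measurement model h are assumptions added for completeness. The extrapolated covariance C_{k/k-1} is computed without any reference to the measurement noise r, which only enters once a measurement is predicted or absorbed.

```python
def kalman_predict(x, c, a, q):
    """Extrapolate state and covariance: X_{k/k-1} and C_{k/k-1}.

    No measurement noise appears here, only the propagation `a`
    and the (assumed) process noise variance `q`.
    """
    return a * x, a * c * a + q


def kalman_update(x_pred, c_pred, z, h, r):
    """Absorb the measurement z, whose noise variance is r."""
    s = h * c_pred * h + r           # innovation variance: this is where r enters
    gain = c_pred * h / s
    x_new = x_pred + gain * (z - h * x_pred)
    c_new = (1.0 - gain * h) * c_pred
    return x_new, c_new, s


x, c = 1.0, 0.5                      # state and variance after k-1 measurements
a, q, h, r = 2.0, 0.0, 1.0, 0.3

x_pred, c_pred = kalman_predict(x, c, a, q)
x_new, c_new, s = kalman_update(x_pred, c_pred, z=2.5, h=h, r=r)
```

In GP terms, c_pred plays the role of the noise-free covariance of f_\ast, while s = c_pred + r (for h = 1) corresponds to the variance of a predicted noisy observation.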
By the way, have a look at GPy: they have introduced a predict_noiseless method:
def predict_noiseless(self, Xnew, full_cov=False, Y_metadata=None, kern=None):
    """
    Convenience function to predict the underlying function of the GP (often
    referred to as f) without adding the likelihood variance on the
    prediction function.
    This is most likely what you want to use for your predictions.
Thanks, @jecampagne - you made excellent points and I agree that we should have a method/option for noise-free prediction. Will add it shortly.
Hi, is it possible for the user to provide his/her own kernel? At the moment, if I understand correctly, kernel can be 'RBF', 'Periodic', or 'Matern'. But one may want to use other kinds of kernels and compose them by addition or multiplication.
Also, why is noise an integer, for example in the following code? Thanks