ziatdinovmax / gpax

Gaussian Processes for Experimental Sciences
http://gpax.rtfd.io
MIT License

Extension of kernel='RBF',... #7

Closed: jecampagne closed this issue 2 years ago

jecampagne commented 2 years ago

Hi, is it possible for users to provide their own kernel? If I understand correctly, at the moment kernel can only be 'RBF', 'Periodic', or 'Matern'. But one may want to use other kinds of kernels and compose them by addition or multiplication.

Also, why for example in the following code

def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray,
                 params: Dict[str, jnp.ndarray],
                 noise: int = 0, **kwargs: float) -> jnp.ndarray:

is noise typed as an integer?

Thanks

ziatdinovmax commented 2 years ago

Yes, I want to add an option for custom kernels in the near future. We can also add more kernels to kernels.py. Do you have any specific kernel in mind? Noise should be a float; thanks for catching that!

jecampagne commented 2 years ago

Hi @ziatdinovmax, I was wondering whether it might be more productive to let users cook up their own kernel; for instance, here is a simple squared exponential kernel.

That way the noise is also in the user's hands:

import jax.numpy as jnp

# squared exponential kernel with diagonal noise term (1D inputs X and Z)
def kernel(X, Z, var, length, noise=0.1, jitter=1.0e-6, include_noise=True):
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        # only valid for the square training matrix K(X, X), i.e. when Z is X
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

and then pass the kernel function to the ExactGP class.
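To illustrate that a plain callable is enough, here is a minimal sketch (not gpax's API) that evaluates the zero-mean GP log marginal likelihood for any kernel following the `kernel(X, Z, **params)` call convention of the function above. `log_marginal_likelihood` is a hypothetical helper; it factors the training covariance with a Cholesky decomposition instead of inverting it.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def kernel(X, Z, var, length, noise=0.1, jitter=1.0e-6, include_noise=True):
    """Squared exponential kernel for 1D inputs (same as above)."""
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k


def log_marginal_likelihood(kernel_fn, X, y, **kernel_kwargs):
    """log p(y | X) for a zero-mean GP with a user-supplied kernel callable."""
    K = kernel_fn(X, X, **kernel_kwargs)       # includes the noise term
    L, lower = cho_factor(K, lower=True)
    alpha = cho_solve((L, lower), y)           # K^{-1} y
    n = y.shape[0]
    # -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2 pi), with log|K| = 2 sum(log diag L)
    return (-0.5 * y @ alpha
            - jnp.sum(jnp.log(jnp.diag(L)))
            - 0.5 * n * jnp.log(2.0 * jnp.pi))
```

Swapping in another kernel only requires a function with the same call convention; nothing in `log_marginal_likelihood` depends on the kernel's form.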

Incidentally, I cannot get noise_prior to work either. I tried to follow the kernel_prior example, but it crashes.

Have a nice day.

jecampagne commented 2 years ago

Hi, here is a class largely inspired by yours.

Note that adding the dense_mass option to NUTS is very valuable.

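from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
from jax import vmap
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median

class GaussProc: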
    """
    Gaussian process class
    Args:
        kernel: GP kernel 
        mean_fn: optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
        kernel_prior: optional custom priors over kernel hyperparameters 
        mean_fn_prior: optional priors over mean function parameters
        noise_prior: optional custom prior for observation noise
    """

    def __init__(self, kernel: Callable[[jnp.ndarray, 
                                         jnp.ndarray, 
                                         Dict[str, jnp.ndarray], 
                                         float, float],jnp.ndarray],
                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior: Optional[Callable[[], jnp.ndarray]] = None
                 ) -> None:
        clear_cache()  # cache-clearing helper borrowed from gpax; not defined in this snippet
        self.kernel = kernel
        self.mean_fn = mean_fn
        self.kernel_prior = kernel_prior
        self.mean_fn_prior = mean_fn_prior
        self.noise_prior = noise_prior
        self.mcmc = None

    def model(self, X: jnp.ndarray, y: jnp.ndarray):
        """GP probabilistic model"""
        # Initialize mean function at zeros
        f_loc = jnp.zeros(X.shape[0])
        # Sample kernel parameters
        kernel_params = self.kernel_prior()
        noise = self.noise_prior()
        # Add mean function (if any)
        if self.mean_fn is not None:
            args = [X]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # compute GP K(X,X)
        K = self.kernel(
            X, X,
            kernel_params,
            noise
        )
        # sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=K),
            obs=y,
        )

    def fit(self, rng_key:jnp.array, X: jnp.ndarray, y: jnp.ndarray, 
                num_warmup: int = 1_000, 
                num_samples: int = 1_000,
                num_chains: int = 1, 
                chain_method: str = 'vectorized',
                dense_mass: bool = True,
                progress_bar: bool = True, 
                print_summary: bool = True
                ) -> None:

        """
        Fit GP Kernel parameters using MCMC NUTS

        Args:
            rng_key: random number generator key
            X: 2D 'feature vector' with :math:`n x num_features`
            y: 1D 'target vector' with :math:`(n,)` dimensions
            num_warmup: number of MCMC warmup states
            num_samples: number of MCMC samples
            num_chains: number of MCMC chains
            chain_method: 'sequential', 'parallel' or 'vectorized'
            dense_mass: diagonal HMC mass matrix or full dense (optimized during warmup)
            progress_bar: show progress bar
            print_summary: print summary at the end of sampling
        """

        init_strategy = init_to_median(num_samples=100)

        kernel_nuts = NUTS(self.model, init_strategy=init_strategy, dense_mass=dense_mass)
        self.mcmc = MCMC(
            kernel_nuts,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method = chain_method,
            progress_bar=progress_bar
        )
        self.mcmc.run(rng_key, X, y)
        if print_summary:
            self.mcmc.print_summary()

    def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
        """Get posterior samples (after running the MCMC chains)"""
        return self.mcmc.get_samples(group_by_chain=chain_dim)

    def get_mvn_posterior(self,
                rng_key:jnp.array, 
                X_train: jnp.ndarray, y_train: jnp.ndarray, 
                X_new: jnp.ndarray, 
                params: Dict[str, jnp.ndarray],
                noise: float = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns the posterior mean and one random draw (mean + std * eps, eps ~ N(0, 1))
        from the multivariate normal posterior for a single sample of GP hyperparameters
        """

        y_residual = y_train
        if self.mean_fn is not None:
            args = [X_train, params] if self.mean_fn_prior else [X_train]
            y_residual -= self.mean_fn(*args).squeeze()
        # compute kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, noise)
        k_pX = self.kernel(X_new, X_train, params, jitter=0.0)
        k_XX = self.kernel(X_train, X_train, params, noise)
        # compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        # one draw of std * eps per test point; jnp.maximum guards against tiny negative variances
        sigma_noise = jnp.sqrt(jnp.maximum(jnp.diag(cov), 0.0)) * jax.random.normal(
            rng_key, X_new.shape[:1])

        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        if  self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()

        return mean, mean+sigma_noise

    def predict(self, rng_key: jnp.ndarray, 
                X_train: jnp.ndarray, y_train: jnp.ndarray, 
                X_new: jnp.ndarray,
                samples: Optional[Dict[str, jnp.ndarray]] = None,
                ) -> Tuple[jnp.ndarray, jnp.ndarray]:

        X_new = X_new if X_new.ndim > 1 else X_new[:, None]

        if samples is None:
            samples = self.get_samples(chain_dim=False)

        num_samples=samples[list(samples.keys())[0]].shape[0]

        # do prediction
        vmap_args = (
            jax.random.split(rng_key, num_samples),
            samples,
            samples["noise"]
        )
        means, predictions = vmap(
            lambda rng_key, samples, noise: self.get_mvn_posterior(
                rng_key, X_train, y_train, X_new, samples, noise)
            )(*vmap_args)

        return means, predictions 
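`get_mvn_posterior` above forms `jnp.linalg.inv(k_XX)` explicitly. A common, numerically safer alternative is to factor k_XX once with a Cholesky decomposition and solve against the factor. The sketch below is standalone: `rbf` and `gp_posterior` are hypothetical names (not gpax functions), the kernel is a plain RBF on inputs of shape (n, d), and it returns the mean and covariance of the latent function with no noise in k_pp.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(X, Z, scale=1.0, length=1.0):
    """Noise-free RBF kernel for inputs of shape (n, d) and (m, d)."""
    r2 = jnp.sum(((X[:, None, :] - Z[None, :, :]) / length) ** 2, axis=-1)
    return scale * jnp.exp(-0.5 * r2)


def gp_posterior(X_train, y_train, X_new, noise, jitter=1e-6, **kp):
    """Posterior mean and covariance of the latent f at X_new."""
    n = X_train.shape[0]
    k_XX = rbf(X_train, X_train, **kp) + (noise + jitter) * jnp.eye(n)
    k_pX = rbf(X_new, X_train, **kp)
    k_pp = rbf(X_new, X_new, **kp)
    chol = cho_factor(k_XX, lower=True)          # factor once, reuse for both solves
    mean = k_pX @ cho_solve(chol, y_train)       # k_pX K^{-1} y
    cov = k_pp - k_pX @ cho_solve(chol, k_pX.T)  # k_pp - k_pX K^{-1} k_Xp
    return mean, cov
```

The result is mathematically the same as the explicit-inverse version; the Cholesky route simply avoids forming K^{-1}, which tends to be less accurate for ill-conditioned kernels.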

Here is how I use it. First, the definitions of all functions and priors:

def mean_fn(x, params):
    """Power-law behavior before and after the transition"""
    return jnp.piecewise(
        x, [x < 0, x >= 0],
        [lambda x: params["R0"]+params["v"]*x, 
         lambda x: params["R0"]+params["v"]*x - params["k"]*(1.-jnp.exp(-x/params["tau"]))
        ])

def mean_fn_prior():
    # Sample model parameters
    R0 = numpyro.sample("R0", numpyro.distributions.Uniform(0.1, 50))
    v = numpyro.sample("v", numpyro.distributions.Uniform(0.1, 50))
    k = numpyro.sample("k", numpyro.distributions.Uniform(0.1,50))
    tau = numpyro.sample("tau", numpyro.distributions.Uniform(0.1, 50))
    # Return sampled parameters as a dictionary
    return { "R0": R0, "v":v, "k":k, "tau":tau}

def kernel_prior():
    length = numpyro.sample("k_length", numpyro.distributions.Normal(0.83, 0.05))
    scale = numpyro.sample("k_scale", numpyro.distributions.LogNormal(0, 1))
    return {"k_length": length, "k_scale": scale}

def noise_prior():
    noise = numpyro.sample("noise", dist.Normal(1.0,0.1 ))
    return noise

Then instantiate the class:

gp = GaussProc(kernel=kernel_RBF, 
                kernel_prior=kernel_prior, 
                noise_prior=noise_prior,
                mean_fn=mean_fn,
                mean_fn_prior=mean_fn_prior)

Fit the model:

rng_key, rng_key_predict = jax.random.split(jax.random.PRNGKey(0))

gp.fit(rng_key,X,y, num_warmup=5_000, num_samples=1_000)

Get the samples and check the parameter fit with ArviZ. Then compute the means and predictions. Note that I do not draw several predictions per parameter set generated by the NUTS MCMC chain.

means, predictions = gp.predict(rng_key_predict, X_train=X, y_train=y, X_new=X_new)
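`predict` returns one row per posterior sample, so both outputs have shape (num_samples, n_new). A minimal sketch of collapsing them into a curve and an uncertainty band for plotting; `summarize_predictions` is a hypothetical helper and the 95% band is an arbitrary choice.

```python
import jax.numpy as jnp


def summarize_predictions(means, predictions, q=(2.5, 97.5)):
    """Collapse per-sample GP outputs of shape (num_samples, n_new).

    Returns the average posterior mean and the lower/upper percentiles
    of the sampled predictions at each new point.
    """
    mean = means.mean(axis=0)
    lower, upper = jnp.percentile(predictions, jnp.array(q), axis=0)
    return mean, lower, upper
```

The mean curve can then be plotted against X_new, with the band between `lower` and `upper` shaded.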

Then I can plot something like this: [figure: GP means and predictions]

Concerning GP kernels, here are a few, but I do not handle adding or multiplying them:

def kernel_RBF(X: jnp.ndarray, 
               Z: jnp.ndarray,  
               params: Dict[str, jnp.ndarray],
               noise: float =0.0, jitter: float=1.0e-6)-> jnp.ndarray:
    r2 = square_scaled_distance(X, Z, params["k_length"])
    k = params["k_scale"] * jnp.exp(-0.5 * r2)
    if X.shape == Z.shape:
        k +=  (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern12(X: jnp.ndarray, 
               Z: jnp.ndarray,  
               params: Dict[str, jnp.ndarray],
               noise: float =0.0, jitter: float=1.0e-6)-> jnp.ndarray:
    """
    Matern nu=1/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    k = params["k_scale"] * jnp.exp(-r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern32(X: jnp.ndarray, 
               Z: jnp.ndarray,  
               params: Dict[str, jnp.ndarray],
               noise: float =0.0, jitter: float=1.0e-6)-> jnp.ndarray:
    """
    Matern nu=3/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    sqrt3_r = 3**0.5 * r
    k = params["k_scale"] * (1.0 + sqrt3_r) * jnp.exp(-sqrt3_r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def kernel_Matern52(X: jnp.ndarray, 
               Z: jnp.ndarray,  
               params: Dict[str, jnp.ndarray],
               noise: float =0.0, jitter: float=1.0e-6)-> jnp.ndarray:
    """
    Matern nu=5/2 kernel
    """
    r2 = square_scaled_distance(X, Z, params["k_length"])
    r = _sqrt(r2)
    sqrt5_r = 5**0.5 * r
    k = params["k_scale"] * (1.0 + sqrt5_r + sqrt5_r**2 /3.0) * jnp.exp(-sqrt5_r)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
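Each kernel above adds (noise + jitter) to its own diagonal, so naively summing two of them would count the noise twice. One way to support addition and multiplication, sketched under my own conventions: evaluate each component with noise=0 and jitter=0, combine them, then add the noise once. The `combine` helper and the `a_`/`b_` key prefixes are hypothetical (the prefixes keep a flat MCMC sample dict usable), and `square_scaled_distance` here is a stand-in written for this sketch, not gpax's helper.

```python
from typing import Callable, Dict

import jax.numpy as jnp

Kernel = Callable[..., jnp.ndarray]


def square_scaled_distance(X, Z, lengthscale=1.0):
    """Stand-in: squared distances between rows of X (n, d) and Z (m, d)."""
    X, Z = X / lengthscale, Z / lengthscale
    return jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)


def kernel_RBF(X, Z, params, noise=0.0, jitter=1.0e-6):
    """Same RBF kernel as above."""
    r2 = square_scaled_distance(X, Z, params["k_length"])
    k = params["k_scale"] * jnp.exp(-0.5 * r2)
    if X.shape == Z.shape:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k


def _subparams(params: Dict[str, jnp.ndarray], prefix: str) -> Dict[str, jnp.ndarray]:
    """Select the keys starting with `prefix` and strip it ('a_k_length' -> 'k_length')."""
    return {key[len(prefix):]: v for key, v in params.items() if key.startswith(prefix)}


def combine(k1: Kernel, k2: Kernel, op: Callable,
            prefix1: str = "a_", prefix2: str = "b_") -> Kernel:
    """Return a kernel computing op(k1, k2) with the same call signature as kernel_RBF."""
    def kernel(X, Z, params, noise=0.0, jitter=1.0e-6):
        # evaluate both components noise-free so that the noise is added only once
        K1 = k1(X, Z, _subparams(params, prefix1), noise=0.0, jitter=0.0)
        K2 = k2(X, Z, _subparams(params, prefix2), noise=0.0, jitter=0.0)
        k = op(K1, K2)
        if X.shape == Z.shape:
            k += (noise + jitter) * jnp.eye(X.shape[0])
        return k
    return kernel
```

For example, `combine(kernel_RBF, kernel_Matern52, jnp.add)` would pair with a kernel_prior that samples `a_k_length`, `a_k_scale`, `b_k_length` and `b_k_scale`; passing `jnp.multiply` gives a product kernel instead.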

Voilà, my two cents. I don't know whether it fits your plans.

ziatdinovmax commented 2 years ago

Looks awesome! I found that it was intuitively easier for end users (usually domain scientists) to select a kernel by passing a string with the kernel name, hence the current setup. But I agree that advanced users should be able to pass a callable. I'm currently working on extending the gp and dkl modules to vector-valued targets, and once that is completed (hopefully this week), I will get back to this issue.

jecampagne commented 2 years ago

Note that the title of the picture is not correct, since this is not a vanilla Gaussian process, but I think you have noticed that yourself.

jecampagne commented 2 years ago

I found that

k_pp = self.kernel(X_new, X_new, params, noise)

is not correct and one should use

k_pp = self.kernel(X_new, X_new, params, jitter=0)

to avoid the "noise" in the "test" extrapolation.

ziatdinovmax commented 2 years ago

Yes, but then you'll have to add noise to your cov.diagonal() unless your function is noiseless ;-) , as in e.g. here. In the end, both approaches produce similar results, but our approach is 'cleaner':

The current implementation:

[figure: GP_full_cov_posterior]

The implementation where we follow the standard procedure from the link above, i.e. we do not include noise in k_pp but add it to cov.diagonal(), such that y_sampled = dist.Normal(y_mean, jnp.sqrt(K)).sample(rng_key, sample_shape=(n,)) where K = cov.diagonal() + noise:

[figure: GP_diag_noise_posterior]

ziatdinovmax commented 2 years ago

I should have clarified that the results are 'similar' and not the 'same' because we used a full covariance matrix for sampling and compared it to a solution that used only the variance. So, if we follow standard practice, do not include noise in k_pp, and then add it to the diagonal of the covariance matrix as

# jax.ops.index_update has been removed from recent JAX releases; use the .at[] indexed-update syntax
diag_elements = jnp.diag_indices_from(cov)
cov_new = cov.at[diag_elements].set(cov.diagonal() + noise)

we will get identical results with the current gpax implementation.
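The claim that both routes give identical full covariance matrices is easy to check numerically. A sketch assuming a 1D RBF kernel with explicit noise handling; `rbf` and `posterior_cov` are hypothetical names, and the `.at[...].add` call is the current JAX idiom for updating the diagonal.

```python
import jax.numpy as jnp


def rbf(X, Z, scale=1.0, length=1.0):
    """Noise-free RBF kernel for 1D inputs."""
    return scale * jnp.exp(-0.5 * ((X[:, None] - Z[None, :]) / length) ** 2)


def posterior_cov(X_train, X_new, noise, noise_in_kpp):
    """Predictive covariance at X_new, with noise added in one of two places."""
    k_XX = rbf(X_train, X_train) + noise * jnp.eye(X_train.shape[0])
    k_pX = rbf(X_new, X_train)
    k_pp = rbf(X_new, X_new)
    if noise_in_kpp:
        # current approach: observation noise is already inside k_pp
        k_pp = k_pp + noise * jnp.eye(X_new.shape[0])
    cov = k_pp - k_pX @ jnp.linalg.solve(k_XX, k_pX.T)
    if not noise_in_kpp:
        # standard approach: noise-free k_pp, then add the noise to the diagonal
        diag = jnp.diag_indices_from(cov)
        cov = cov.at[diag].add(noise)
    return cov
```

Since the subtracted term does not depend on where the noise enters k_pp, the two covariances agree up to floating-point rounding.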

jecampagne commented 2 years ago

Well, I was referring to

C. E. Rasmussen and C. K. I. Williams, Gaussian Processes for Machine Learning, The MIT Press, 2006,

page 16.

The K(X_*, X_*) term is free of measurement noise (sigma_n), since this is a prediction, but it is not free of "noise" in the sense that the kernel scale has the same meaning.
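For reference, the standard predictive equations for noisy observations $y = f(x) + \varepsilon$, $\varepsilon \sim \mathcal N(0, \sigma_n^2)$, written in the book's notation, show where the noise term enters:

```latex
\begin{aligned}
\bar{\mathbf f}_* &= K(X_*, X)\,\big[K(X, X) + \sigma_n^2 I\big]^{-1}\mathbf y,\\
\operatorname{cov}(\mathbf f_*) &= K(X_*, X_*) - K(X_*, X)\,\big[K(X, X) + \sigma_n^2 I\big]^{-1} K(X, X_*),\\
\operatorname{cov}(\mathbf y_*) &= \operatorname{cov}(\mathbf f_*) + \sigma_n^2 I.
\end{aligned}
```

The noise variance appears in the latent-function prediction only inside the inverse of the training covariance; adding $\sigma_n^2 I$ to $\operatorname{cov}(\mathbf f_*)$ gives the covariance of the noisy targets $\mathbf y_*$ instead.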

Note that in JAX's gaussian_process_regression.py example, no noise is added to the amp * cov_map(exp_quadratic, xtest) term (here).

ziatdinovmax commented 2 years ago

The logic here is that the new/unseen data is assumed to follow the same distribution as the training data. Hence, if we introduced a model noise for the training data, we also want to include that noise in prediction. I think we can introduce a keyword argument to the prediction method(s) that allows a user to select between noise-free and noisy predictions.
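A keyword like the one proposed could look roughly as follows. This is a hypothetical sketch (`sample_predictive` and its `noiseless` flag are illustrative names, not gpax's or GPy's API); it assumes `noise` is the observation-noise variance and that the posterior covariance of the latent function was computed without noise in k_pp.

```python
import jax
import jax.numpy as jnp


def sample_predictive(rng_key, mean, cov_f, noise, noiseless=False, n_draws=1, jitter=1e-6):
    """Draw samples from the GP predictive distribution at new points.

    mean, cov_f: posterior mean and noise-free covariance of the latent f_*.
    noiseless=True  -> samples of f_* (no likelihood variance added)
    noiseless=False -> samples of y_* = f_* + eps, with eps ~ N(0, noise * I)
    """
    n = cov_f.shape[0]
    cov = cov_f + jitter * jnp.eye(n)  # small jitter for a stable Cholesky
    if not noiseless:
        cov = cov + noise * jnp.eye(n)
    return jax.random.multivariate_normal(rng_key, mean, cov, shape=(n_draws,))
```

With `noiseless=True` the spread of the draws reflects only uncertainty about f; with the default, it also includes the observation noise, which is what new measurements would show.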

jecampagne commented 2 years ago

Well, take Kalman filtering, for instance: up to a certain step you have absorbed (k-1) points, and you have everything in hand to produce X_{k/k-1}, the extrapolation of the state vector X (here via linear propagation given by an A matrix), and C_{k/k-1}, its covariance matrix.
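For readers less familiar with the analogy, the textbook linear Kalman filter in this notation reads as follows, where Q is the process-noise covariance, R the measurement-noise covariance, H the observation matrix, z_k the measurement at step k and G_k the gain (these symbols are introduced here for illustration). The measurement noise R only appears in the update step, once measurement k is actually used:

```latex
\begin{aligned}
\text{predict:}\quad & X_{k/k-1} = A\,X_{k-1/k-1}, \qquad C_{k/k-1} = A\,C_{k-1/k-1}A^\top + Q,\\
\text{update:}\quad & G_k = C_{k/k-1}H^\top\big(H\,C_{k/k-1}H^\top + R\big)^{-1},\\
& X_{k/k} = X_{k/k-1} + G_k\big(z_k - H\,X_{k/k-1}\big), \qquad C_{k/k} = (I - G_k H)\,C_{k/k-1}.
\end{aligned}
```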


Now, if you want to add a new measurement at step k, you of course update your state vector and covariance matrix with the information from this new measurement. But before considering this new measurement, you do not include any noise coming from a measurement in the extrapolation step.

The same analogy applies to the prediction of f_\ast and its covariance matrix, which are the equivalents of X_{k/k-1} and C_{k/k-1} in the Kalman filtering case.

I would encourage you to add a keyword and a few lines explaining this (in better English). Best.

jecampagne commented 2 years ago

By the way, have a look at GPy: they have introduced a predict_noiseless method:


def predict_noiseless(self,  Xnew, full_cov=False, Y_metadata=None, kern=None):
    """
    Convenience function to predict the underlying function of the GP (often
    referred to as f) without adding the likelihood variance on the
    prediction function.
    This is most likely what you want to use for your predictions.
    ...
    """

ziatdinovmax commented 2 years ago

Thanks, @jecampagne - you made excellent points and I agree that we should have a method/option for noise-free prediction. Will add it shortly.