ziatdinovmax / gpax

Gaussian Processes for Experimental Sciences
http://gpax.rtfd.io
MIT License

Is it possible to filter nans for ExactGP.predict inside UCB Acquisition functions calls? #36

Closed avivajpeyi closed 1 year ago

avivajpeyi commented 1 year ago

Sometimes gpax.acquisition.UCB returns NaNs. This may be because ExactGP.predict returns NaNs for some samples?

See the plot below (the UCB curve is missing because the array is full of NaNs):

Screen Shot 2023-08-23 at 3 34 26 pm

Changing the random seed for data generation:

Screen Shot 2023-08-23 at 3 35 38 pm
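On the question in the title, one possible way to guard against NaNs on the caller's side is to mask non-finite acquisition values before choosing the next point. The sketch below is only an illustration in plain NumPy: `best_finite_index` is a hypothetical helper, not part of gpax, and it assumes the acquisition values are to be maximized. When every value is NaN, as in the first plot, there is nothing to filter, so it returns `None` to signal that the model itself needs attention.

```python
import numpy as np


def best_finite_index(acq_values):
    """Return the index of the largest finite value, or None if none is finite.

    Hypothetical helper for illustration; not part of the gpax API.
    """
    acq = np.asarray(acq_values, dtype=float)
    finite = np.isfinite(acq)
    if not finite.any():
        # Every entry is NaN or inf: filtering cannot help here
        return None
    # Replace non-finite entries with -inf so argmax skips them
    return int(np.argmax(np.where(finite, acq, -np.inf)))


# Made-up acquisition values, some of them NaN
example = np.array([0.2, np.nan, 0.9, np.nan, 0.5])
all_nan = np.full(5, np.nan)

idx_example = best_finite_index(example)
idx_all_nan = best_finite_index(all_nan)
print(idx_example, idx_all_nan)
```

Masking like this hides the symptom only when some values are finite; an array that is entirely NaN usually points to a problem in the model setup.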
Code to reproduce NaNs:

```python
import gpax
import numpy as np
import matplotlib.pyplot as plt
import numpyro

gpax.utils.enable_x64()

SEED = 1
N_OBS = 6
X_RANGE = (-2, 5)

np.random.seed(SEED)


def observations(x, noise_sigma=0.05):
    noise = np.random.normal(0, noise_sigma, len(x))
    f = 1 / (x ** 2 + 1) * np.cos(np.pi * x)
    return f + noise


def generate_data():
    X_measured = np.random.uniform(*X_RANGE, N_OBS)
    X_unmeasured = np.linspace(*X_RANGE, 50)
    y_measured = observations(X_measured)
    y_true = observations(X_unmeasured, noise_sigma=0)
    return X_measured, y_measured, X_unmeasured, y_true


def get_gp_preds(X_measured, y_measured, X_unmeasured):
    rng_key1, rng_key2 = gpax.utils.get_keys(SEED)
    noise_prior = numpyro.distributions.Normal(1)
    gp_model = gpax.ExactGP(1, kernel='RBF', noise_prior_dist=noise_prior)
    gp_model.fit(rng_key1, X_measured, y_measured)
    y_pred, y_sampled = gp_model.predict(rng_key2, X_unmeasured, noiseless=True)
    y_up = np.nanquantile(y_sampled, 0.95, axis=0).ravel()
    y_low = np.nanquantile(y_sampled, 0.05, axis=0).ravel()
    ucb_values = gpax.acquisition.UCB(
        rng_key2, gp_model, X_unmeasured,
        beta=4, maximize=False, noiseless=True)
    return y_pred, y_up, y_low, ucb_values


def plot(X_measured, y_measured, X_unmeasured, y_true, y_pred, y_up, y_low, ucb_values):
    fig, ax = plt.subplots(1, 1, figsize=(4, 3))
    ax.plot(X_unmeasured, y_true, lw=3, ls='--', c='k', label='True', alpha=0.1)
    ax.scatter(X_measured, y_measured, c='k', label="Observations")
    ax.plot(X_unmeasured, y_pred, lw=2, c='tab:orange', label='Model')
    ax.fill_between(X_unmeasured, y_low, y_up, color='tab:orange', alpha=0.3)
    ax2 = ax.twinx()
    ax2.plot(X_unmeasured, ucb_values, lw=1.5, color='tab:purple', alpha=0.9, zorder=-100)
    ax.plot([], [], lw=1.5, color='tab:purple', alpha=0.9, zorder=-100, label='UCB')
    ax.legend(frameon=True)
    ax.set_xlim(X_RANGE)
    ax2.set_yticks([])
    fig.show()


def main():
    X_measured, y_measured, X_unmeasured, y_true = generate_data()
    y_pred, y_up, y_low, ucb_values = get_gp_preds(
        X_measured, y_measured, X_unmeasured
    )
    plot(
        X_measured, y_measured, X_unmeasured, y_true,
        y_pred, y_up, y_low, ucb_values
    )
    print(f"X_measured = {X_measured.tolist()}")
    print(f"y_measured = {y_measured.tolist()}")
    print(f"X_unmeasured = {X_unmeasured.tolist()}")


if __name__ == '__main__':
    main()
```

Data:

```
X_measured = [0.9191540329180179, 3.042271454095107, -1.9991993762785858, 0.1163280084228786, -0.9727087642802088, -1.3536298366184154]
y_measured = [-0.5510701474083296, -0.150299323805448, 0.24339790464416963, 0.8064144297202153, -0.42470373051379506, -0.19475225108675304]
X_unmeasured = [-2.0, -1.8571428571428572, -1.7142857142857144, -1.5714285714285714, -1.4285714285714286, -1.2857142857142858, -1.1428571428571428, -1.0, -0.8571428571428572, -0.7142857142857144, -0.5714285714285716, -0.4285714285714286, -0.2857142857142858, -0.14285714285714302, 0.0, 0.1428571428571428, 0.2857142857142856, 0.4285714285714284, 0.5714285714285712, 0.714285714285714, 0.8571428571428568, 1.0, 1.1428571428571428, 1.2857142857142856, 1.4285714285714284, 1.5714285714285712, 1.714285714285714, 1.8571428571428568, 2.0, 2.1428571428571423, 2.2857142857142856, 2.428571428571428, 2.571428571428571, 2.7142857142857144, 2.8571428571428568, 3.0, 3.1428571428571423, 3.2857142857142856, 3.428571428571428, 3.571428571428571, 3.7142857142857135, 3.8571428571428568, 4.0, 4.142857142857142, 4.285714285714286, 4.428571428571428, 4.571428571428571, 4.7142857142857135, 4.857142857142857, 5.0]
```

ziatdinovmax commented 1 year ago

@avivajpeyi I think this is because your noise prior, Normal(1), allows negative values, but noise cannot be negative. If you change it to HalfNormal(1), everything works. You can visualize distributions using gpax.utils.dviz(numpyro.distributions.Normal(1)) ->
image
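For readers without the plot, the point about negative values can be checked numerically. The snippet below is a rough NumPy-only sketch, not gpax or numpyro code. It assumes Normal(1) means a mean of 1 with numpyro's default scale of 1, and it emulates HalfNormal(1) as the absolute value of a standard normal draw. The variable names are made up for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Stand-in for Normal(1): mean 1, standard deviation 1 (assumed default scale)
normal_draws = rng.normal(loc=1.0, scale=1.0, size=n)

# Stand-in for HalfNormal(1): |z| with z ~ Normal(0, 1)
half_normal_draws = np.abs(rng.normal(loc=0.0, scale=1.0, size=n))

# Fraction of draws that would be invalid as a noise variance
frac_negative_normal = float(np.mean(normal_draws < 0))
frac_negative_half = float(np.mean(half_normal_draws < 0))

print(f"Normal(1):     {frac_negative_normal:.1%} of draws are negative")
print(f"HalfNormal(1): {frac_negative_half:.1%} of draws are negative")
```

Under these assumptions roughly one draw in six from the normal prior is negative, while the half-normal stand-in never is. In the reproduction script, the change suggested above amounts to setting noise_prior = numpyro.distributions.HalfNormal(1).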

avivajpeyi commented 1 year ago

Ah, of course!! Thanks again @ziatdinovmax :)