ziatdinovmax / gpax

Gaussian Processes for Experimental Sciences
http://gpax.rtfd.io
MIT License

MCMC prediction stability issue: NaN values of variance (issue in Google Colab GPU/High-RAM setting) #34

Open arpanbiswas52 opened 1 year ago

arpanbiswas52 commented 1 year ago

This issue was encountered in Google Colab with the T4 GPU setting and High-RAM enabled.

When we run ExactGP.fit(), it produces NaN values in the standard deviation calculation. The error can be reproduced with all of the modifications I have already tried.

As a current workaround, reducing the total number of MCMC samples to num_warmup=500, num_samples=500 (the defaults are num_warmup=1000, num_samples=3000) seems to produce reasonable outputs.
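A rough back-of-the-envelope sketch of why fewer samples could help, if the NaNs come from running out of device memory during vectorized prediction. It assumes (hypothetically; this is not taken from the gpax source) that every posterior sample materializes one dense `n_test x n_test` float64 covariance matrix and that all of them are held in memory at once. The 1000 test points are a made-up number, and `predictive_cov_bytes` / `to_gib` are illustrative helpers, not gpax functions.

```python
# Back-of-the-envelope memory estimate for a fully vectorized (vmapped)
# prediction over all posterior samples. ASSUMPTION (hypothetical): every
# posterior sample materializes one dense n_test x n_test float64 covariance,
# and all of them are held in memory at the same time.

BYTES_PER_FLOAT64 = 8


def predictive_cov_bytes(num_samples: int, n_test: int,
                         bytes_per_float: int = BYTES_PER_FLOAT64) -> int:
    """Bytes for num_samples dense (n_test, n_test) covariance matrices."""
    return num_samples * n_test * n_test * bytes_per_float


def to_gib(n_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return n_bytes / 2**30


if __name__ == "__main__":
    n_test = 1000  # hypothetical number of prediction points
    for num_samples in (3000, 500):  # default vs. reduced setting from the issue
        gib = to_gib(predictive_cov_bytes(num_samples, n_test))
        print(f"num_samples={num_samples}: ~{gib:.1f} GiB")
```

Under these assumptions the default setting needs about 22 GiB, more than the 16 GB on a T4, while num_samples=500 needs under 4 GiB. Only num_samples enters the estimate, since warmup draws are normally discarded after MCMC finishes.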

nmenon97 commented 11 months ago

I encountered the same issue without using Google Colab or a GPU.

ziatdinovmax commented 11 months ago

@nmenon97 - did it happen in the 64-bit precision regime?

nmenon97 commented 11 months ago

@ziatdinovmax yes! I can sometimes get it to work by changing num_warmup and num_samples as suggested above, but it isn't a stable solution.

ziatdinovmax commented 10 months ago

This is due to the peculiar behavior of jax.vmap when approaching a memory limit. There are three ways to deal with this:

  1. Draw multiple random smaller batches of samples (see gpax.acquisition.qEI, .qUCB, etc.; you can specify the batch size with the subsample_size argument) and average them.
  2. Assume that the acquisition function is continuous and use gpax.acquisition.optimize_acq to optimize it with num_initial_guesses << total_number_of_points.
  3. Use a device with a larger memory :)
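The first option can be sketched as follows. This is a minimal illustration, not gpax's implementation: NumPy and a Python loop stand in for jax.vmap (in JAX you would vmap `predict_one` over the subsample only), the posterior samples of (lengthscale, amplitude, noise) are synthetic, and `predict_one`, `predict_subsampled`, and `n_batches` are hypothetical names. Each batch combines its per-sample predictions with the law of total variance, and the batch results are then averaged; this is one reasonable way to combine them, not necessarily what qEI/qUCB do internally. Peak memory scales with `subsample_size` instead of the total number of samples.

```python
import numpy as np


def rbf(x1, x2, lengthscale, amplitude):
    # Squared-exponential kernel between 1D inputs x1 (n,) and x2 (m,)
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return amplitude * np.exp(-0.5 * d2 / lengthscale**2)


def predict_one(params, X, y, X_new):
    """GP predictive mean and latent variance for ONE posterior sample."""
    ls, amp, noise = params
    K = rbf(X, X, ls, amp) + noise * np.eye(len(X))
    K_s = rbf(X, X_new, ls, amp)                         # (n_train, n_new)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    mean = K_s.T @ alpha
    v = np.linalg.solve(L, K_s)
    var = amp - np.sum(v**2, axis=0)                     # k(x*, x*) = amp for RBF
    return mean, var


def predict_subsampled(samples, X, y, X_new, subsample_size, n_batches, seed=0):
    """Average predictions over random subsamples of the posterior samples."""
    rng = np.random.default_rng(seed)
    batch_means, batch_vars = [], []
    for _ in range(n_batches):
        idx = rng.choice(samples.shape[0], size=subsample_size, replace=False)
        # Only subsample_size predictions are held in memory at a time
        preds = [predict_one(samples[i], X, y, X_new) for i in idx]
        means = np.stack([m for m, _ in preds])
        variances = np.stack([v for _, v in preds])
        # Law of total variance within the batch
        batch_means.append(means.mean(axis=0))
        batch_vars.append(variances.mean(axis=0) + means.var(axis=0))
    return np.mean(batch_means, axis=0), np.mean(batch_vars, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    X = np.linspace(0, 1, 20)
    y = np.sin(6 * X) + 0.05 * rng.standard_normal(20)
    X_new = np.linspace(0, 1, 50)
    # Synthetic stand-in for MCMC posterior samples: (lengthscale, amplitude, noise)
    samples = np.column_stack([
        rng.uniform(0.1, 0.3, 200),
        rng.uniform(0.5, 1.5, 200),
        rng.uniform(1e-3, 1e-2, 200),
    ])
    mean, var = predict_subsampled(samples, X, y, X_new,
                                   subsample_size=50, n_batches=4)
    print(mean.shape, bool(np.isfinite(var).all()))
```

With `subsample_size` equal to the total number of samples and a single batch, this reduces to the full-posterior estimate, so the batch size trades memory for Monte Carlo noise.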