arpanbiswas52 opened 1 year ago
Encountered the same issue without using Google Colab or GPU
@nmenon97 - did it happen in the 64-bit precision regime?
@ziatdinovmax yes! I can sometimes get it to work by changing num_warmup and num_samples as suggested above but it isn't a stable solution.
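For reference, 64-bit precision in JAX is opt-in and has to be enabled before any arrays are created; a minimal sketch:

```python
import jax

# JAX defaults to float32; enable 64-bit precision globally.
# This must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled
```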
This is due to the peculiar behavior of jax.vmap when approaching a memory limit. There are three ways to deal with this:

- Use the batched acquisition functions (gpax.acquisition.qEI, .qUCB, etc.; you can specify the batch size using the subsample_size argument) and average them.
- Use gpax.acquisition.optimize_acq to optimize it with num_initial_guesses << total_number_of_points.
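The batching idea can be sketched in plain JAX (the acquisition function below is a hypothetical stand-in, since the exact gpax call signatures depend on the version installed):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an acquisition function evaluated at one point.
def acq(x):
    return jnp.sum(jnp.sin(x) ** 2)

# 1000 candidate points in 10 dimensions.
X = jnp.linspace(0.0, 1.0, 10_000).reshape(1000, 10)

# A single vmap over all candidates can exhaust GPU memory:
#   values = jax.vmap(acq)(X)
# Instead, evaluate fixed-size batches and concatenate the results.
def batched_eval(f, X, batch_size=100):
    chunks = [jax.vmap(f)(X[i:i + batch_size])
              for i in range(0, X.shape[0], batch_size)]
    return jnp.concatenate(chunks)

values = batched_eval(acq, X)
print(values.shape)  # (1000,)
```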
This is the issue encountered in Google Colab, under the T4 GPU and High-RAM setting. When we run ExactGP.fit(), it produces NaN values in the standard deviation calculation. The error can be reproduced with all of the modifications below; these are things I have already tried.
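As a toy illustration of this failure mode (not gpax code): in float32, a predictive variance that is analytically non-negative can round to a slightly negative value, and taking its square root then yields NaN for the standard deviation:

```python
import jax.numpy as jnp

# Hypothetical variance value: analytically ~0, but rounded slightly
# negative by float32 arithmetic.
var = jnp.float32(-1e-8)

# sqrt of a negative float is NaN, which propagates into the
# reported standard deviation.
std = jnp.sqrt(var)
print(std)  # nan
```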
As a current workaround, reducing the total number of samples in the MCMC settings to num_warmup=500, num_samples=500 (the default is num_warmup=1000, num_samples=3000) lets it produce reasonable outputs, but this isn't a stable solution.