NotNANtoN closed this issue 2 years ago
It looks like some acquisition function values calculated from the Gaussian process model are NaN; this might be due to the numerical stability of the GP model. Do you have sample code that I can use to reproduce this crash?
Based on your design space, I assume you are tuning an XGBoost model. I wrote the code below using hebo.sklearn_tuner.sklearn_tuner to see if anything goes wrong, but everything looks OK:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt
from hebo.optimizers.hebo import HEBO
from hebo.sklearn_tuner import sklearn_tuner
from hebo.design_space import DesignSpace
from xgboost import XGBRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.datasets import load_boston
space_cfg = [{'name': 'lr', 'type' : 'num', 'lb' : 0.00005, 'ub' : 0.1},
             {'name': 'n_estimators', 'type' : 'int', 'lb' : 10, 'ub' : 200}, # multiplied by 10
             {'name': 'max_depth', 'type' : 'int', 'lb' : 1, 'ub' : 10},
             {'name': 'subsample', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
             {'name': 'colsample_bytree', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
             {'name': 'gamma', 'type' : 'num', 'lb' : 0.01, 'ub' : 10.0},
             {'name': 'min_child_weight', 'type' : 'int', 'lb' : 1, 'ub' : 10},
             {'name': 'fill_type', 'type' : 'cat', 'categories' : ['median', 'pat_median', 'pat_ema']},
             {'name': 'flat_block_size', 'type' : 'int', 'lb' : 1, 'ub' : 1},
             {'name': 'verbosity', 'type' : 'int', 'lb' : 0, 'ub' : 0}
            ]
X, y = load_boston(return_X_y = True)
cv = KFold(n_splits = 10, shuffle = True, random_state = 42)
result, report = sklearn_tuner(XGBRegressor, space_cfg, X, y, r2_score, cv = cv, max_iter = 64, report = True)
print(report)
report.metric.plot()
plt.show()
Hi, thanks a lot for your answer!
I tried your code and it seems to run fine. I am not using sklearn_tuner; instead I call opt.suggest manually. I modified my code to use the Boston dataset and that also seems to work... But I noticed that it gets quite slow after about 50 steps - is that due to the changed hyperparameters, or does HEBO simply take much longer as the number of iterations grows?
As for my problem, it still occurs. Unfortunately my dataset is not publicly available. I'm basically running this loop, where the obj function returns an r2_score:
for i in range(opt_steps):
    rec = opt.suggest()
    if "bs" in rec:
        rec["bs"] = 2 ** rec["bs"]
    if "n_estimators" in rec:
        rec["n_estimators"] *= 10
    print(i)
    print(list(zip(rec.columns, rec.values[0])))
    start_time = time.time()
    opt.observe(rec, obj(df, cfg, rec))
    print("Opt time: ", time.time() - start_time)
    min_idx = np.argmin(opt.y)
    print("Current score:", 1 - opt.y[-1][0])
    print("Best score so far:", 1 - opt.y[min_idx][0])
    print(f'After {i} iterations, best obj is {1 - opt.y[min_idx][0]:.4f} with params {opt.X.iloc[min_idx][0]}')
    print()
This is my full output:
0
[('lr', 4.999999873689376e-05), ('n_estimators', 10), ('max_depth', 1), ('subsample', 0.5), ('colsample_bytree', 0.5), ('gamma', 0.009999999776482582), ('min_child_weight', 0.009999999776482582), ('fill_type', 'median')]
Opt time: 8.373520851135254
Current score: 0.06503021650669472
Best score so far: 0.06503021650669472
After 0 iterations, best obj is 0.0650 with params 4.999999873689376e-05
1
[('lr', 0.05002500116825104), ('n_estimators', 100), ('max_depth', 6), ('subsample', 0.7450000047683716), ('colsample_bytree', 0.7450000047683716), ('gamma', 2.504999876022339), ('min_child_weight', 2.504999876022339), ('fill_type', 'pat_ema')]
Opt time: 12.982976198196411
Current score: 0.22393181747644808
Best score so far: 0.22393181747644808
After 1 iterations, best obj is 0.2239 with params 0.05002500116825104
2
[('lr', 0.09783410604757857), ('n_estimators', 90), ('max_depth', 10), ('subsample', 0.9768197084360646), ('colsample_bytree', 0.9498448347173593), ('gamma', 4.438085471364415), ('min_child_weight', 4.949910987243103), ('fill_type', 'pat_ema')]
Opt time: 19.771536111831665
Current score: 0.27596342318731093
Best score so far: 0.27596342318731093
After 2 iterations, best obj is 0.2760 with params 0.09783410604757857
3
[('lr', 0.09534893514045893), ('n_estimators', 20), ('max_depth', 10), ('subsample', 0.958603449440306), ('colsample_bytree', 0.9897392694341535), ('gamma', 4.597594328222324), ('min_child_weight', 4.54544847187763), ('fill_type', 'pat_ema')]
Opt time: 13.476808786392212
Current score: 0.2911832440771226
Best score so far: 0.2911832440771226
After 3 iterations, best obj is 0.2912 with params 0.09534893514045893
4
[('lr', 0.08974096345257622), ('n_estimators', 30), ('max_depth', 10), ('subsample', 0.9875473709445614), ('colsample_bytree', 0.9899865032565488), ('gamma', 4.502709642709985), ('min_child_weight', 0.1675799138458133), ('fill_type', 'pat_ema')]
Opt time: 14.120468378067017
Current score: 0.281744205852656
Best score so far: 0.2911832440771226
After 4 iterations, best obj is 0.2912 with params 0.09534893514045893
5
[('lr', 0.056794210828607895), ('n_estimators', 100), ('max_depth', 5), ('subsample', 0.7841850110546084), ('colsample_bytree', 0.761872148068391), ('gamma', 4.660798376085397), ('min_child_weight', 1.706409948969703), ('fill_type', 'median')]
Opt time: 9.955632448196411
Current score: 0.31906698897163577
Best score so far: 0.31906698897163577
After 5 iterations, best obj is 0.3191 with params 0.056794210828607895
6
[('lr', 0.06194953234522322), ('n_estimators', 190), ('max_depth', 9), ('subsample', 0.5245152196754684), ('colsample_bytree', 0.7979904402789458), ('gamma', 4.999710831827684), ('min_child_weight', 1.31406233972266), ('fill_type', 'median')]
Opt time: 18.016623735427856
Current score: 0.23482889590088885
Best score so far: 0.31906698897163577
After 6 iterations, best obj is 0.3191 with params 0.056794210828607895
7
[('lr', 0.04779509898029476), ('n_estimators', 80), ('max_depth', 5), ('subsample', 0.7323350700727138), ('colsample_bytree', 0.7339707453149883), ('gamma', 4.681703025075991), ('min_child_weight', 1.731618124440871), ('fill_type', 'median')]
Opt time: 9.83765172958374
Current score: 0.3134999604020726
Best score so far: 0.31906698897163577
After 7 iterations, best obj is 0.3191 with params 0.056794210828607895
8
[('lr', 0.026268868648699047), ('n_estimators', 180), ('max_depth', 5), ('subsample', 0.6306731708933159), ('colsample_bytree', 0.7468692282491458), ('gamma', 3.7312811738115284), ('min_child_weight', 1.661273660312038), ('fill_type', 'median')]
Opt time: 11.977625846862793
Current score: 0.34071786913084745
Best score so far: 0.34071786913084745
After 8 iterations, best obj is 0.3407 with params 0.026268868648699047
9
[('lr', 0.017358057301979493), ('n_estimators', 190), ('max_depth', 5), ('subsample', 0.7265885258298752), ('colsample_bytree', 0.755437663728787), ('gamma', 3.294058784514677), ('min_child_weight', 2.0467435360563058), ('fill_type', 'median')]
Opt time: 11.51999807357788
Current score: 0.3195188613478257
Best score so far: 0.34071786913084745
After 9 iterations, best obj is 0.3407 with params 0.026268868648699047
10
[('lr', 0.043367372765737065), ('n_estimators', 200), ('max_depth', 4), ('subsample', 0.6074111596682685), ('colsample_bytree', 0.7127508642287075), ('gamma', 4.236515970760105), ('min_child_weight', 1.2258569949494702), ('fill_type', 'none')]
Opt time: 9.249167919158936
Current score: 0.3131071323890887
Best score so far: 0.34071786913084745
After 10 iterations, best obj is 0.3407 with params 0.026268868648699047
11
[('lr', 0.02077600494088448), ('n_estimators', 200), ('max_depth', 5), ('subsample', 0.8374782341247184), ('colsample_bytree', 0.545661817026894), ('gamma', 3.9953090379449208), ('min_child_weight', 0.7574154138166923), ('fill_type', 'median')]
Opt time: 11.864658832550049
Current score: 0.3306416509958666
Best score so far: 0.34071786913084745
After 11 iterations, best obj is 0.3407 with params 0.026268868648699047
12
[('lr', 0.06935039082110936), ('n_estimators', 50), ('max_depth', 5), ('subsample', 0.6314640561076242), ('colsample_bytree', 0.9636399285390094), ('gamma', 3.6190603244245025), ('min_child_weight', 3.0198621856406427), ('fill_type', 'median')]
Opt time: 9.606015682220459
Current score: 0.3283411937461205
Best score so far: 0.34071786913084745
After 12 iterations, best obj is 0.3407 with params 0.026268868648699047
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_2820306/4048605528.py in <module>
45
46 for i in range(opt_steps):
---> 47 rec = opt.suggest()
48 if "bs" in rec:
49 rec["bs"] = 2 ** rec["bs"]
~/.local/lib/python3.8/site-packages/hebo/optimizers/hebo.py in suggest(self, n_suggestions, fix_input)
151 sig = Sigma(model, linear_a = -1.)
152 opt = EvolutionOpt(self.space, acq, pop = 100, iters = 100, verbose = False, es=self.es)
--> 153 rec = opt.optimize(initial_suggest = best_x, fix_input = fix_input).drop_duplicates()
154 rec = rec[self.check_unique(rec)]
155
~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in optimize(self, initial_suggest, fix_input, return_pop)
125 crossover = self.get_crossover()
126 algo = get_algorithm(self.es, pop_size = self.pop, sampling = init_pop, mutation = mutation, crossover = crossover, repair = self.repair)
--> 127 res = minimize(prob, algo, ('n_gen', self.iter), verbose = self.verbose)
128 if res.X is not None and not return_pop:
129 opt_x = res.X.reshape(-1, len(lb)).astype(float)
~/.local/lib/python3.8/site-packages/pymoo/optimize.py in minimize(problem, algorithm, termination, copy_algorithm, copy_termination, **kwargs)
81
82 # actually execute the algorithm
---> 83 res = algorithm.run()
84
85 # store the deep copied algorithm in the result object
~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in run(self)
211 # while termination criterion not fulfilled
212 while self.has_next():
--> 213 self.next()
214
215 # create the result object to be returned
~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in next(self)
231 # call the advance with them after evaluation
232 if infills is not None:
--> 233 self.evaluator.eval(self.problem, infills, algorithm=self)
234 self.advance(infills=infills)
235
~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in eval(self, problem, pop, skip_already_evaluated, evaluate_values_of, count_evals, **kwargs)
93 # actually evaluate all solutions using the function that can be overwritten
94 if len(I) > 0:
---> 95 self._eval(problem, pop[I], evaluate_values_of=evaluate_values_of, **kwargs)
96
97 # set the feasibility attribute if cv exists
~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in _eval(self, problem, pop, evaluate_values_of, **kwargs)
110 evaluate_values_of = self.evaluate_values_of if evaluate_values_of is None else evaluate_values_of
111
--> 112 out = problem.evaluate(pop.get("X"),
113 return_values_of=evaluate_values_of,
114 return_as_dictionary=True,
~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in evaluate(self, X, return_values_of, return_as_dictionary, *args, **kwargs)
122
123 # do the actual evaluation for the given problem - calls in _evaluate method internally
--> 124 self.do(X, out, *args, **kwargs)
125
126 # make sure the array is 2d before doing the shape check
~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in do(self, X, out, *args, **kwargs)
160
161 def do(self, X, out, *args, **kwargs):
--> 162 self._evaluate(X, out, *args, **kwargs)
163 out_to_2d_ndarray(out)
164
~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in _evaluate(self, x, out, *args, **kwargs)
46
47 with torch.no_grad():
---> 48 acq_eval = self.acq(xcont, xenum).numpy().reshape(num_x, self.acq.num_obj + self.acq.num_constr)
49 out['F'] = acq_eval[:, :self.acq.num_obj]
50
~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in __call__(self, x, xe)
37
38 def __call__(self, x : Tensor, xe : Tensor):
---> 39 return self.eval(x, xe)
40
41 class SingleObjectiveAcq(Acquisition):
~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in eval(self, x, xe)
155 normed = ((self.tau - self.eps - py - noise * torch.randn(py.shape)) / ps)
156 dist = Normal(0., 1.)
--> 157 log_phi = dist.log_prob(normed)
158 Phi = dist.cdf(normed)
159 PI = Phi
~/.local/lib/python3.8/site-packages/torch/distributions/normal.py in log_prob(self, value)
71 def log_prob(self, value):
72 if self._validate_args:
---> 73 self._validate_sample(value)
74 # compute the variance
75 var = (self.scale ** 2)
~/.local/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
286 valid = support.check(value)
287 if not valid.all():
--> 288 raise ValueError(
289 "Expected value argument "
290 f"({type(value).__name__} of shape {tuple(value.shape)}) "
ValueError: Expected value argument (Tensor of shape (100, 1)) to be within the support (Real()) of the distribution Normal(loc: 0.0, scale: 1.0), but found invalid values:
tensor([[-1.9959],
[-1.2675],
[-1.2204],
[-0.5946],
[-1.3163],
[-0.8091],
[-2.3450],
[-1.1690],
[-1.2374],
[-0.5374],
[-0.8852],
[-1.5104],
[-1.8167],
[ 0.3373],
[-1.0077],
[-1.5388],
[ 0.9909],
[-0.9809],
[-1.0140],
[-0.1807],
[-0.5176],
[-0.3398],
[-1.5057],
[-1.3493],
[-1.3827],
[-0.7947],
[-2.6809],
[-0.7844],
[-1.4292],
[-0.8269],
[-1.6755],
[-1.6348],
[-0.7895],
[-0.8264],
[-1.3902],
[-0.5924],
[-1.4093],
[-0.8154],
[ 0.2801],
[-0.6707],
[-1.0585],
[-1.5289],
[-1.2883],
[-0.6418],
[-3.6011],
[ nan],
[-1.3098],
[-2.6957],
[-0.9912],
[ 0.4284],
[-1.6822],
[-0.5964],
[-0.1601],
[-1.2632],
[-0.8173],
[-0.1966],
[ 1.8093],
[ 0.5075],
[-0.6223],
[-1.1435],
[-0.7424],
[-1.6756],
[ 1.7556],
[-1.5124],
[-1.4938],
[-0.6549],
[-0.6919],
[-0.4789],
[-1.6914],
[-1.8472],
[-0.3958],
[-1.9369],
[-1.5689],
[-0.7813],
[-0.8114],
[-0.9482],
[-0.9427],
[-1.5766],
[-0.6994],
[-1.2480],
[-1.1529],
[-1.0359],
[-1.6211],
[-1.1925],
[-0.7662],
[-0.9530],
[-0.0925],
[ 0.1829],
[-1.6802],
[-1.7956],
[-1.6634],
[-1.8606],
[-1.1047],
[-0.5844],
[-1.0566],
[-1.6968],
[-0.9914],
[-0.8555],
[-1.4518],
[-1.6394]])
And opt.y is:
array([[0.93496978],
[0.77606818],
[0.72403658],
[0.70881676],
[0.71825579],
[0.68093301],
[0.7651711 ],
[0.68650004],
[0.65928213],
[0.68048114],
[0.68689287],
[0.66935835],
[0.67165881]])
I'm quite sure that this error is not dependent on XGBoost, because it also happens when I train an RNN. It even crashed there on the third suggest step.
I really don't know what is happening here. I'm now trying different surrogate models. As far as I can see there are gpy (the default), gp, gpy_mlp and rf. As the error seems related to the Gaussian process, I'm trying rf first. But are there any metrics or evaluations of how well it performs? I could not really find anything in the arXiv paper at https://arxiv.org/pdf/2012.03826.pdf
rf has not crashed so far, even after 70 steps - but if it performs worse then of course I don't want to use it.
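For reference, this is roughly how I am switching the surrogate now - assuming the model_name argument of the HEBO constructor is the intended way to select it (that is what I gathered from the source); obj, df, cfg, opt_steps and space_cfg are from my own code above:

from hebo.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO

space = DesignSpace().parse(space_cfg)   # same design space as before
opt   = HEBO(space, model_name = 'rf')   # random forest surrogate instead of the default GP

for i in range(opt_steps):
    rec = opt.suggest()
    opt.observe(rec, obj(df, cfg, rec))  # obj/df/cfg come from my own code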
10
[('lr', 0.043367372765737065), ('n_estimators', 200), ('max_depth', 4), ('subsample', 0.6074111596682685), ('colsample_bytree', 0.7127508642287075), ('gamma', 4.236515970760105), ('min_child_weight', 1.2258569949494702), ('fill_type', 'none')]
Why is the fill_type 'none'?
@MdAsifKhan I think I have found the reason; it's because of these lines of code:
if "bs" in rec:
    rec["bs"] = 2 ** rec["bs"]
if "n_estimators" in rec:
    rec["n_estimators"] *= 10
By doing this you modify rec, so the rec you pass to observe is not the same as the one returned by suggest. For example, your n_estimators is defined within [1, 20], but the n_estimators you pass to observe lies within [10, 200].
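If you do want to keep a manual transformation, one safe pattern is to transform a copy and pass the untouched suggestion back to observe - a rough sketch, where obj, df and cfg stand for your own objective and data:

rec = opt.suggest()

# transform a copy for the objective; rec itself stays exactly as suggest() returned it
rec_eval = rec.copy()
if "bs" in rec_eval:
    rec_eval["bs"] = 2 ** rec_eval["bs"]
if "n_estimators" in rec_eval:
    rec_eval["n_estimators"] *= 10

# evaluate the transformed copy, but report back the original suggestion
opt.observe(rec, obj(df, cfg, rec_eval))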
It looks like you want n_estimators to be a multiple of 10 and bs to be an integer power of 2. HEBO actually has built-in support for these requirements, so you don't need the manual transformation at all.
You can write the space configuration like this:
import pandas as pd
import numpy as np
from hebo.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO
np.random.seed(42)
space = DesignSpace().parse([
    {'name': 'n_estimators', 'type' : 'step_int', 'lb' : 10, 'ub' : 200, 'step' : 10}, # multiplied by 10
    {'name': 'bs', 'type' : 'int_exponent', 'lb' : 16, 'ub' : 1024, 'base' : 2},       # 2**(int)
])
print(space.sample(10))
The output would look like this:
n_estimators bs
0 70 64
1 200 512
2 150 256
3 110 32
4 80 128
5 70 512
6 190 512
7 110 32
8 110 128
9 40 256
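Once your full design space is written with these parameter types, the loop can pass rec back to observe unmodified - a sketch reusing obj, df, cfg and opt_steps from your code above:

opt = HEBO(space)

for i in range(opt_steps):
    rec = opt.suggest()
    # rec already contains multiples of 10 for n_estimators and powers of 2 for bs
    opt.observe(rec, obj(df, cfg, rec))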
Ah, thank you so much! That fixes the issue for me. Maybe you could raise a warning if the rec passed to observe differs from the last output of suggest?
Also thanks for the distribution tips - I wanted to do it my way because I want to compare Optuna and HEBO and therefore want to modify these parameters independently.
Hi, thanks for this repository! So far it works quite well, but now I suddenly encountered a weird error after 11 optimization steps of non-batched HEBO:
Seems like there is a NaN in some distribution of HEBO. But my input parameters (opt.X) and losses (opt.y) are never NaN. This is the design space I'm using:
I already commented out flat_block_size as I thought that maybe it is a problem if lb == ub, but it still crashes. Any ideas on how I can debug this?
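(For reference, a quick check like the following, run before every suggest, is how I convince myself that opt.X and opt.y are clean - just a sketch:)

import numpy as np

assert not opt.X.isnull().values.any()                    # no missing/NaN entries in the observed parameters
assert np.isfinite(np.asarray(opt.y, dtype=float)).all()  # no NaN/inf in the observed objective values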