Closed gentrexha closed 2 years ago
Hi, sorry for our late reply. Indeed it's not well documented how to use instances. We will add this documentation shortly.
There are basically two ways to include instances:
In addition to the instances, you can add instance features, which are passed as a dict to the scenario object, e.g. {'features': {0: (15, 15, 15), 1: (15, 15, 16)}}. Those features are used to train the surrogate model (see https://github.com/automl/SMAC3/blob/master/smac/runhistory/runhistory2epm.py#L122).
Finally, you can use these two examples to integrate the instances:
Here's another example:
import typing
import uuid
import shutil
import logging
from itertools import chain, combinations
from ConfigSpace.hyperparameters import CategoricalHyperparameter, UniformFloatHyperparameter, Constant
from ConfigSpace.conditions import InCondition, AndConjunction
from smac.configspace import ConfigurationSpace, Configuration
from smac.facade.smac_hpo_facade import SMAC4HPO
from smac.scenario.scenario import Scenario
from smac.initial_design.latin_hypercube_design import LHDesign
import os
import sys
from os.path import dirname, abspath
parent_path = dirname(dirname(abspath(__file__)))
sys.path.append(parent_path)
from bobo.optimizer import SMAC4EPMOpimizer
from bobo.meta_space import _add_gp_map_cs, _add_gp_ml_cs, _add_ytrans_cs, _add_init_design_cs, _add_rf_cs
from bobo.run_bayesmark import run_experiment
from bobo.meta_score import compute_score
from bobo.tae_bbo import ExcuteTABBO
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the command-line options of this script.

    The only option is ``--smac_model``, which selects the surrogate-model
    combination and defaults to ``'GP'``.

    Returns:
        argparse.Namespace: parsed arguments with a ``smac_model`` attribute.
    """
    import argparse

    # A tuple (rather than a set) keeps the order of the choices stable in the
    # usage/help text and in "invalid choice" error messages.
    model_choices = ('RF', 'GP', 'GPmap', 'RF_GP', 'RF_GPmap')
    parser = argparse.ArgumentParser(description="smac_bobo")
    parser.add_argument('--smac_model', type=str, choices=model_choices, default='GP')
    return parser.parse_args()
# https://stackoverflow.com/a/1482316
def powerset(iterable):
    """Return an iterator over every non-empty combination of ``iterable``.

    Combinations are produced as tuples, shortest first, e.g.
    powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3).
    The empty combination is not included.
    """
    pool = list(iterable)
    sizes = range(1, len(pool) + 1)
    per_size = (combinations(pool, size) for size in sizes)
    return chain.from_iterable(per_size)
def run_target_bo(config: Configuration,
                  instance: str,
                  db_root: str,
                  db_id: str,
                  instance_to_uuid: typing.Dict[typing.Tuple[str, str, str], typing.List[str]],
                  run_uuid):
    """Evaluate one optimizer configuration on one instance and return its cost.

    Args:
        config: configuration handed to ``SMAC4EPMOpimizer`` via ``opt_kwargs``.
        instance: instance name of the form ``<model>_<dataset>_<metric>``;
            it is split on the first two underscores only.
        db_root: database root, resolved relative to ``parent_path``.
        db_id: database identifier forwarded to the experiment and the scorer.
        instance_to_uuid: mapping forwarded to the experiment and the scorer.
        run_uuid: identifier forwarded to ``run_experiment``.

    Returns:
        The value returned by ``compute_score``; here cost is actually
        100 - score.
    """
    model_name, dataset, metric = instance.split('_', 2)

    # Arguments shared by the experiment run and the scoring step.
    db_kwargs = {
        'db_root': os.path.join(parent_path, db_root),
        'db_id': db_id,
        'instance_to_uuid': instance_to_uuid,
    }

    run_experiment(opt_class=SMAC4EPMOpimizer,
                   opt_kwargs={'config': config},
                   opt_name='smac-ensemble',
                   run_uuid=run_uuid,
                   model_name=model_name,
                   dataset=dataset,
                   scorer=metric,
                   **db_kwargs)

    cost = compute_score(model_names=[model_name],
                         datasets=[dataset],
                         metrics=[metric],
                         **db_kwargs)

    # The log line reports the score, i.e. 100 - cost.
    message = "get score {:.4f} on dataset {}, classifier {} with metric {}".format(
        100-cost, dataset, model_name, metric)
    logger.info(message)
    return cost
def build_config_space(smac_model: str, acq_funcs, model_names, datasets):
    """Build the meta configuration space and return a SMAC4HPO optimizer for it.

    Args:
        smac_model: surrogate-model identifier (e.g. ``'RF_GP'``); stored as a
            constant hyperparameter and used to decide which model-specific
            sub-spaces are added.
        acq_funcs: acquisition-function names; every non-empty combination,
            joined with ``'_'``, becomes one categorical choice.
        model_names: names of the target models used to enumerate instances.
        datasets: mapping from dataset name to the list of its metrics.

    Returns:
        A ``SMAC4HPO`` object whose scenario has one instance per
        (model, dataset, metric) triple, named ``<model>_<dataset>_<metric>``.
    """
    acq_func_list = ['_'.join(combo) for combo in powerset(acq_funcs)]

    cs = ConfigurationSpace()
    hp_smac_model = Constant('smac_models', smac_model)
    smac_acq_func = CategoricalHyperparameter('acq_funcs', choices=acq_func_list)

    # NOTE(review): these are substring tests, so 'GPmap' also satisfies the
    # 'GP' check and both GP sub-spaces are added for it -- confirm intended.
    if 'RF' in smac_model:
        _add_rf_cs(cs)
    if 'GP' in smac_model:
        _add_gp_ml_cs(cs)
    if 'GPmap' in smac_model:
        _add_gp_map_cs(cs)
    cs.add_hyperparameters([hp_smac_model, smac_acq_func])

    # One parameter per acquisition function, each active only when the chosen
    # combination contains that function.
    lcb_par = UniformFloatHyperparameter("LCB_par", 10 ** -2, 10 ** 2, default_value=0.1, log=True)
    ei_par = UniformFloatHyperparameter("EI_par", 0, 1, default_value=0)
    pi_par = UniformFloatHyperparameter("PI_par", 0, 1, default_value=0)
    logei_par = UniformFloatHyperparameter("logEI_par", 0, 1, default_value=0)

    lcb_values = [name for name in acq_func_list if 'LCB' in name]
    # 'EI' is matched on whole tokens so that 'logEI' alone does not enable EI_par.
    ei_values = [name for name in acq_func_list if 'EI' in name.split('_')]
    pi_values = [name for name in acq_func_list if 'PI' in name]
    logei_values = [name for name in acq_func_list if 'logEI' in name]

    conditions = [
        InCondition(lcb_par, smac_acq_func, lcb_values),
        InCondition(ei_par, smac_acq_func, ei_values),
        InCondition(pi_par, smac_acq_func, pi_values),
        InCondition(logei_par, smac_acq_func, logei_values),
    ]
    cs.add_hyperparameters([lcb_par, ei_par, pi_par, logei_par])
    cs.add_conditions(conditions)

    _add_ytrans_cs(cs)
    _add_init_design_cs(cs)

    rand_prob = UniformFloatHyperparameter('rand_prob', 0, 0.5, default_value=0.25)
    parallel_setting = CategoricalHyperparameter('parallel_setting',
                                                 choices=['CL_min', 'CL_max', 'CL_mean', 'KB', 'LS'],
                                                 default_value='LS')
    cs.add_hyperparameters([rand_prob, parallel_setting])

    # Instance features are (model index, dataset index, metric code).
    metric_codes = {'mae': 0, 'mse': 1, 'acc': 2, 'nll': 3}
    instances = []
    instance_features = {}
    for model_idx, model_name in enumerate(model_names):
        for dataset_idx, dataset in enumerate(datasets):
            for metric in datasets[dataset]:
                name = '_'.join((model_name, dataset, metric))
                # Each instance is wrapped in its own single-element list.
                instances.append([name])
                instance_features[name] = (model_idx, dataset_idx, metric_codes[metric])

    scenario_options = {
        'run_obj': 'quality',
        "algo_runs_timelimit": 3600 * 24*2,  # 2 days
        'cs': cs,
        'runcount-limit': 5000,
        'deterministic': False,
        'instances': instances,
        'features': instance_features,
        'maxR': 3,
        'output_dir': os.path.join(parent_path, 'bobo'),
        'limit_resources': False,
    }
    scenario = Scenario(scenario_options)

    smac = SMAC4HPO(scenario=scenario,
                    initial_design=LHDesign,
                    tae_runner=ExcuteTABBO,
                    tae_runner_kwargs={'ta': run_target_bo, 'scenario': scenario},
                    initial_design_kwargs={'n_configs_x_params': 2})
    return smac
def main():
    """Build the optimizer for the requested surrogate model and run it.

    Reads ``--smac_model`` from the command line, builds the SMAC object over
    all (model, dataset, metric) instances, copies the baseline file to
    ``derived/baseline.json`` under the SMAC output directory, runs the
    optimization and prints the final incumbent.
    """
    args = parse_args()

    acq_funcs = ["LCB", "EI", "PI", "logEI"]
    # Metrics evaluated for each dataset.
    datasets = {'boston': ['mae', 'mse'],
                'breast': ['acc', 'nll'],
                'diabetes': ['mae', 'mse'],
                'digits': ['acc', 'nll'],
                'iris': ['acc', 'nll'],
                'wine': ['acc', 'nll']}
    model_names = ['ada', 'DT', 'kNN', 'lasso', 'linear', 'MLP-adam', 'MLP-sgd', 'RF', 'SVM']

    smac = build_config_space(smac_model=args.smac_model,
                              acq_funcs=acq_funcs,
                              model_names=model_names,
                              datasets=datasets)

    baseline = os.path.join(parent_path, 'input/baseline-16-8.json')
    derived_dir = os.path.join(smac.output_dir, 'derived')
    # exist_ok avoids a FileExistsError when the directory is already there,
    # e.g. when the script is re-run against the same output directory.
    os.makedirs(derived_dir, exist_ok=True)
    # NOTE(review): presumably the scoring code reads this copy -- confirm.
    shutil.copy(baseline, os.path.join(derived_dir, 'baseline.json'))

    incumbent = smac.optimize()
    print("final incumbent")
    print(incumbent.get_dictionary())


if __name__ == '__main__':
    main()
Best, René
Description
I've created my version of the Simulated Annealing algorithm to solve AutoML problems. Now I'm trying to optimize my initial Temperature parameter. For this, I've taken the OpenML CC-18 Benchmark Suite and started running SMAC for each dataset in the benchmark.
This results in a different optimal temperature value for each dataset. I want to find a single optimized temperature value for all datasets. Is there a way to give the datasets as a parameter to my scenario and let SMAC try out different datasets?
I haven't found much useful material regarding this (ref: #454 & smac_ac_facade docs). It is clear to me that SMAC4HPO does not support instances. But on the other hand, while SMAC4AC does, I am not sure how to make use of that.
Steps/Code to Reproduce
As of now, I am using SMAC4HPO and iterating over the datasets in the benchmark suite like this:
Inside my `sa_from_cfg` function I just load the dataset from `openml` through the `task_id`, train-test-split it, and call my `SimulatedAnneal.fit(X_train, y_train)` function, which does all the AutoML work and returns the F1 score on the test data.
Expected Results
Provide a list of datasets, or directory containing datasets, and SMAC returns the optimal value for all instances.
Actual Results
Currently, I am receiving one optimal temperature value per dataset, while providing the dataset as a `Constant`. I could choose the `task_id` randomly from all the options, and thereby optimize my algorithm over "all" instances, but this feels very hacky to me. Therefore, I wanted to know if there are any other alternatives to this?
Versions
0.13.1