facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License

Multi-task BO with Service API #2546

Open sgbaird opened 1 week ago

sgbaird commented 1 week ago

Should I perish the thought?

xref: #1038

Fairly naive, non-functional starting point:

import numpy as np

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models, ST_MTGP_trans
from ax.service.ax_client import AxClient, ObjectiveProperties

obj1_name = "branin"

def branin(x1, x2):
    # Branin synthetic test function (standard benchmark, minimized here).
    y = float(
        (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
        + 10
    )
    return y

gs = GenerationStrategy(
    steps=[
        # Single ST_MTGP step with the standard multi-task transform stack.
        # Note: no Sobol initialization step precedes it.
        GenerationStep(
            model=Models.ST_MTGP,
            num_trials=-1,
            max_parallelism=3,
            model_kwargs={"transforms": ST_MTGP_trans, "transform_configs": None},
        ),
    ]
)

ax_client = AxClient(generation_strategy=gs, random_seed=42)

ax_client.create_experiment(
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
    ],
    objectives={
        obj1_name: ObjectiveProperties(minimize=True),
    },
)

for _ in range(10):
    parameterization, trial_index = ax_client.get_next_trial()

    # extract parameters
    x1 = parameterization["x1"]
    x2 = parameterization["x2"]

    results = branin(x1, x2)
    ax_client.complete_trial(trial_index=trial_index, raw_data=results)

best_parameters, metrics = ax_client.get_best_parameters()
Console output:

(honegumi) PS C:\Users\sterg\Documents\GitHub\sgbaird\honegumi> & C:/Users/sterg/miniforge3/envs/honegumi/python.exe c:/Users/sterg/Documents/GitHub/sgbaird/honegumi/scripts/refreshers/multi_task.py
[INFO 06-26 16:25:51] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[WARNING 06-26 16:25:51] ax.service.ax_client: Random seed set to 42. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0])], parameter_constraints=[]).
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': -3.0, 'x2': 5.0}] as trial 0.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 0 with data: {'branin': (48.620235, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 0.0, 'x2': 6.2}] as trial 1.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 1 with data: {'branin': (19.642113, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 5.9, 'x2': 2.0}] as trial 2.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 2 with data: {'branin': (19.70361, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 1.5, 'x2': 2.0}] as trial 3.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 3 with data: {'branin': (14.301934, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 1.0, 'x2': 9.0}] as trial 4.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 4 with data: {'branin': (35.100744, None)}.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '0') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '1') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '2') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '3') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '4') is constant, within tolerance.
C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated.  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ..\torch\csrc\utils\tensor_new.cpp:620.)
  summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\acquisition\cached_cholesky.py:89: RuntimeWarning: `cache_root` is only supported for GPyTorchModels that are not MultiTask models and don't produce a TransformedPosterior. Got a model of type <class 'botorch.models.model_list_gp_regression.ModelListGP'>. Setting `cache_root = False`.
  warnings.warn(
Traceback (most recent call last):
  File "c:\Users\sterg\Documents\GitHub\sgbaird\honegumi\scripts\refreshers\multi_task.py", line 114, in <module>
    parameterization, trial_index = ax_client.get_next_trial()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\utils\common\executils.py", line 163, in actual_wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\service\ax_client.py", line 539, in get_next_trial
    generator_run=self._gen_new_generator_run(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\service\ax_client.py", line 1790, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_strategy.py", line 370, in gen
    return self._gen_multiple(
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_strategy.py", line 683, in _gen_multiple
    generator_run = self._curr.gen(
                    ^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 712, in gen
    gr = super().gen(
         ^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 272, in gen
    generator_run = self._gen(
                    ^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 334, in _gen
    return model_spec.gen(
           ^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\model_spec.py", line 221, in gen
    return fitted_model.gen(**model_gen_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\base.py", line 786, in gen
    gen_results = self._gen(
                  ^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\torch.py", line 686, in _gen
    gen_results = not_none(self.model).gen(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch_modular\model.py", line 428, in gen
    candidates, expected_acquisition_value, weights = acqf.optimize(
                                                      ^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch_modular\acquisition.py", line 450, in optimize
    candidates, acqf_values = optimize_acqf(
                              ^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 567, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 588, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 400, in _optimize_acqf_batch
    batch_candidates = opt_inputs.post_processing_func(batch_candidates)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch.py", line 531, in botorch_rounding_func
    [rounding_func(x) for x in X.view(-1, d)]  # pyre-ignore: [16]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch.py", line 531, in <listcomp>
    [rounding_func(x) for x in X.view(-1, d)]  # pyre-ignore: [16]
     ^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\torch.py", line 287, in <lambda>
    self._array_to_tensor(array_func(x.detach().cpu().clone().numpy()))
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\modelbridge_utils.py", line 664, in _roundtrip_transform
    observation_features = t.untransform_observation_features(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\transforms\choice_encode.py", line 122, in untransform_observation_features
    obsf.parameters[p_name] = reverse_transform[pval]
                              ~~~~~~~~~~~~~~~~~^^^^^^
KeyError: 3.7423404678702354

The KeyError is raised while untransforming a choice-encoded parameter (reverse_transform has no entry for the continuous candidate value), which suggests the task/choice encoding isn't surviving the acquisition optimization round trip. Would this be something for BOTORCH_MODULAR in a custom generation strategy instead?
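
Something along these lines is what I have in mind (untested sketch; whether Ax would wire up MultiTaskGP's task feature correctly for non-batch trials is exactly what I'm unsure about):

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.models.multitask import MultiTaskGP

gs = GenerationStrategy(
    steps=[
        # quasi-random initialization before any model-based generation
        GenerationStep(model=Models.SOBOL, num_trials=5),
        # pin the modular BoTorch surrogate to a multi-task GP (untested)
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={"surrogate": Surrogate(botorch_model_class=MultiTaskGP)},
        ),
    ]
)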

danielcohenlive commented 1 week ago

Great question @sgbaird! This is something we do internally in the Service API with batch trials. We plan to open source our AxBatchClient, but unfortunately it's not out yet.
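
Concretely, a batch trial on the underlying Experiment looks roughly like this (AxClient doesn't expose batch trials directly; the arm values below are hypothetical):

from ax.core.arm import Arm

# A batch trial groups several arms under one trial index, which is what
# the MTGP setup treats as a "task". This lives on the core Experiment
# object rather than on AxClient.
experiment = ax_client.experiment
batch = experiment.new_batch_trial()
batch.add_arm(Arm(parameters={"x1": 0.0, "x2": 6.2}))  # hypothetical arm
batch.add_arm(Arm(parameters={"x1": 1.5, "x2": 2.0}))  # hypothetical arm
batch.mark_running(no_runner_required=True)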

With batch trials, you'd have a GS consisting of:

  1. SOBOL
  2. GPEI or BOTORCH_MODULAR without fixed_features and status_quo_features
    • At this point, there's only one trial, so you can't do multi-task yet.
  3. ST_MTGP or BOTORCH_MODULAR with fixed_features and status_quo_features (see the sketch after this list)
    • The fixed_features and status_quo_features point to a trial index, so you'd want them both to point to the target trial, probably the most recent one. I'm not aware of any way to group non-batch trials into tasks.
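
Putting that together, a hedged sketch (the exact plumbing for passing fixed_features through a GenerationStep is my best guess, not something I've verified):

from ax.core.observation import ObservationFeatures
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

TARGET_TRIAL_INDEX = 1  # hypothetical: index of the target (most recent) batch trial

gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=1),
        # only one trial exists at this point, so stay single-task
        GenerationStep(model=Models.GPEI, num_trials=1),
        # switch to the multi-task model once several batch trials exist
        GenerationStep(
            model=Models.ST_MTGP,
            num_trials=-1,
            model_gen_kwargs={
                # best guess: point candidate generation at the target trial;
                # status_quo_features would need similar treatment (omitted here,
                # as I'm not sure of its exact kwarg path through GenerationStep)
                "fixed_features": ObservationFeatures(
                    parameters={}, trial_index=TARGET_TRIAL_INDEX
                ),
            },
        ),
    ]
)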

What are you trying to do? I noticed honegumi in the prompt. Is this for the honegumi interface or a real-world use case? And is it intentional or accidental that this use case uses non-batch trials?