facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License

Pre-fit ModelListGP and pass it to Models.BOTORCH_MODULAR in Ax Client #2471

Closed. Runyu-Zhang closed this issue 2 months ago.

Runyu-Zhang commented 3 months ago

Hello, thank you all for creating and maintaining this wonderful Ax library for BO. I have a question about pre-fitting a GP model and passing it to AxClient for multi-objective optimization. Similar questions can be found at https://github.com/pytorch/botorch/issues/2299 and https://github.com/pytorch/botorch/issues/1750. I followed @Balandat's suggestion (https://github.com/facebook/Ax/issues/1647#issuecomment-1581428813) to use the Modular BoTorch Model interface to pass in a new custom BoTorch class together with a pre-fitted model (Surrogate?). I created one Sobol step in the GenerationStrategy just to avoid "ax.exceptions.core.DataRequiredError: StandardizeY transform requires non-empty data." However, I still get the following error:

    suggested_motor_positions, trial_index = ax_client_bom.get_next_trial()
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/utils/common/executils.py", line 161, in actual_wrapper
    return func(*args, **kwargs)
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/service/ax_client.py", line 531, in get_next_trial
    generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/service/ax_client.py", line 1763, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/generation_strategy.py", line 478, in gen
    return self._gen_multiple(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/generation_strategy.py", line 662, in _gen_multiple
    self._fit_current_model(data=data)
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/generation_strategy.py", line 723, in _fit_current_model
    self._curr.fit(experiment=self.experiment, data=data, **model_state_on_lgr)
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/generation_node.py", line 262, in fit
    model_spec.fit(  # Stores the fitted model as model_spec._fitted_model
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/model_spec.py", line 146, in fit
    self._fitted_model = self.model_enum(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/registry.py", line 380, in __call__
    model_bridge = bridge_class(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/torch.py", line 129, in __init__
    super().__init__(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/base.py", line 208, in __init__
    self._fit_if_implemented(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/base.py", line 231, in _fit_if_implemented
    self._fit(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/modelbridge/torch.py", line 647, in _fit
    self.model.fit(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/models/torch/botorch_modular/model.py", line 290, in fit
    surrogate.fit(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/models/torch/botorch_modular/surrogate.py", line 536, in fit
    should_use_model_list = use_model_list(
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/site-packages/ax/models/torch/botorch_modular/utils.py", line 47, in use_model_list
    if issubclass(botorch_model_class, MultiTaskGP):
  File "/Users/runyu/miniconda3/envs/py3919/lib/python3.9/abc.py", line 123, in __subclasscheck__
    return _abc_subclasscheck(cls, subclass)
TypeError: issubclass() arg 1 must be a class

This is likely related to the fact that botorch_model_class in Surrogate() cannot be a ModelList. Is there a way to do this? Basically, I want to pass a pre-fitted multi-objective GP model to Ax and continue the MOBO process. My code implementation:

def create_gp_model_kwargs():
    return {'surrogate': Surrogate(botorch_model_class=train_default_gp_model()),
            'botorch_acqf_class': qNoisyExpectedHypervolumeImprovement}

def train_default_gp_model():
    train_x, train_y = initialize_training_data()
    train_x = normalize(X=train_x, bounds=params_bounds_tensor)
    models = []
    for i in range(train_y.shape[-1]):
        models.append(SingleTaskGP(train_X=train_x,
                                   train_Y=train_y[:, i].unsqueeze(-1),
                                   outcome_transform=Standardize(m=1)))
    model = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    return model

parameters_list = [dict(name=param, type='range', bounds=range_values, value_type='float')
                   for param in all_parameters]

objectives = {obj: ObjectiveProperties(minimize=True, threshold=hypervolume_reference_point)
              for obj in all_objectives}

steps = [GenerationStep(model=Models.SOBOL,  # Add 1 Sobol trial to avoid "ax.exceptions.core.DataRequiredError: StandardizeY transform requires non-empty data."
                        num_trials=1),
         GenerationStep(model=Models.BOTORCH_MODULAR,
                        num_trials=num_bo_trials,
                        model_kwargs=create_gp_model_kwargs())]

ax_client = AxClient(generation_strategy=GenerationStrategy(steps=steps), random_seed=fixed_seed, verbose_logging=True)
ax_client.create_experiment(name='Ax_Modular', parameters=parameters_list, objectives=objectives, overwrite_existing_experiment=True)

for i in range(BO_trials):
    suggested_values, trial_index = ax_client.get_next_trial()
    loss_values = evaluate_suggested_values(suggested_values=suggested_values)
    ax_client.complete_trial(trial_index=trial_index, raw_data=loss_values.copy())

Runyu-Zhang commented 3 months ago

I reformatted the code.



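from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.acquisition.multi_objective import qNoisyExpectedHypervolumeImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import normalize
from gpytorch.mlls import SumMarginalLogLikelihood

# initialize_training_data, params_bounds_tensor, all_parameters, range_values,
# all_objectives, hypervolume_reference_point, num_bo_trials, fixed_seed,
# BO_trials and evaluate_suggested_values are user-defined and not shown here.

def create_gp_model_kwargs():
    # Passes the fitted ModelListGP returned by train_default_gp_model() as botorch_model_class
    return {'surrogate': Surrogate(botorch_model_class=train_default_gp_model()),
            'botorch_acqf_class': qNoisyExpectedHypervolumeImprovement}

def train_default_gp_model():
    train_x, train_y = initialize_training_data()
    # Scale inputs to the unit cube using the parameter bounds
    train_x = normalize(X=train_x, bounds=params_bounds_tensor)
    # One single-output GP per objective
    models = []
    for i in range(train_y.shape[-1]):
        models.append(SingleTaskGP(train_X=train_x,
                                   train_Y=train_y[:, i].unsqueeze(-1),
                                   outcome_transform=Standardize(m=1)))
    # Build the model list once, after the loop, then fit the hyper-parameters of every sub-model
    model = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    return model

parameters_list = [dict(name=param,
                        type='range',
                        bounds=range_values,
                        value_type='float')
                   for param in all_parameters]
objectives = {obj: ObjectiveProperties(minimize=True, threshold=hypervolume_reference_point)
              for obj in all_objectives}
steps = [GenerationStep(model=Models.SOBOL,  # Add 1 Sobol trial to avoid "ax.exceptions.core.DataRequiredError: StandardizeY transform requires non-empty data."
                        num_trials=1),
         GenerationStep(model=Models.BOTORCH_MODULAR,
                        num_trials=num_bo_trials,
                        model_kwargs=create_gp_model_kwargs())]

# Create Ax Client and optimize:
ax_client = AxClient(generation_strategy=GenerationStrategy(steps=steps),
                     random_seed=fixed_seed,
                     verbose_logging=True)
ax_client.create_experiment(name='Ax_Modular',
                            parameters=parameters_list,
                            objectives=objectives,
                            overwrite_existing_experiment=True)
for i in range(BO_trials):
    suggested_values, trial_index = ax_client.get_next_trial()
    # Throws the error after one Sobol trial
    loss_values = evaluate_suggested_values(suggested_values=suggested_values)
    ax_client.complete_trial(trial_index=trial_index, raw_data=loss_values.copy())
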
Cesar-Cardoso commented 3 months ago

Hello there! In this case you're getting an exception because botorch_model_class should be a class type rather than an instance. Your train_default_gp_model() method returns an instance of a Model (specifically a ModelListGP). When Ax then runs the issubclass(botorch_model_class, MultiTaskGP) check inside use_model_list, issubclass() rejects the instance and raises the TypeError.
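The failure can be reproduced without Ax or BoTorch installed. In this sketch, `MultiTaskModel` and `ModelList` are hypothetical stand-ins (not the real BoTorch classes), and `uses_multitask` only mimics the kind of `issubclass` check seen in the traceback: passing the class works, while passing an instance raises the same TypeError.

```python
class MultiTaskModel:
    """Hypothetical stand-in for botorch.models.MultiTaskGP."""


class ModelList:
    """Hypothetical stand-in for botorch.models.ModelListGP."""


def uses_multitask(model_class):
    # Mimics the check in the traceback: it expects a class, not an instance.
    return issubclass(model_class, MultiTaskModel)


def check(model_class):
    try:
        return uses_multitask(model_class)
    except TypeError as err:
        return f"TypeError: {err}"


class_result = check(ModelList)       # a class: the check runs normally
instance_result = check(ModelList())  # an instance: issubclass() rejects it
print(class_result)
print(instance_result)
```

In the issue's setting, this is why the value given to `botorch_model_class` needs to be a model class (for example `SingleTaskGP` itself) that Ax constructs and fits, rather than the fitted object returned by `train_default_gp_model()`.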

Perhaps what you want to do is pass a callable model argument to GenerationStep as described here?

saitcakmak commented 3 months ago

Hi @Runyu-Zhang. We do not have proper support for this use case -- we used to have some support, but we removed it since it was very difficult to use correctly. The Ax Model and ModelBridge layers are designed to transform the data into BoTorch datasets and use these to construct and fit the BoTorch models. Why are you trying to use a pre-trained model rather than letting Ax train the model using the same fit_gpytorch_mll method? Do you need to pass in certain arguments to the BoTorch model that the API currently doesn't support? If so, I'd like to learn about these and see if it would make sense for us to support them more generally. In general, I'd strongly recommend following this tutorial (https://botorch.org/tutorials/custom_botorch_model_in_ax) to create a custom BoTorch model class and let Ax do the training.

RyanSaat commented 2 months ago

@saitcakmak Hi Saitcakmak. I also find retraining a pre-trained model useful because there are times when you might feel the model is not good and you wish to train it further.

saitcakmak commented 2 months ago

Hi @RyanSaat. Since Ax re-constructs the model using the up-to-date data in each iteration, I don't know what a pre-trained model would look like here (since it would have to be different in each iteration). The models are trained using the fit_gpytorch_mll helper from BoTorch, which uses the L-BFGS-B optimizer under the hood to find a local optimum of the model hyper-parameters. Since the trained model is already at a local optimum, re-training it using the same optimizer typically does not result in any difference. That said, it would be possible to attempt model training with different initial conditions.
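A toy illustration of the local-optimum argument, under two stated assumptions: a one-dimensional function with two local minima stands in for the GP's negative marginal log likelihood, and plain gradient descent stands in for L-BFGS-B (this is not what fit_gpytorch_mll actually runs). Restarting the optimizer from its own converged point leaves it where it was, while starting from several different initial conditions can find a better optimum.

```python
def loss(x):
    # Toy objective with two local minima (near x = -1.30 and x = 1.13);
    # a stand-in for the negative marginal log likelihood of a GP.
    return x**4 - 3 * x**2 + x


def grad(x):
    return 4 * x**3 - 6 * x + 1


def descend(x, lr=0.01, tol=1e-10, max_iter=100_000):
    """Plain gradient descent until the step size drops below tol."""
    for _ in range(max_iter):
        step = lr * grad(x)
        x -= step
        if abs(step) < tol:
            break
    return x


# "Pre-train", then "re-train" starting from the fitted value.
fitted = descend(1.0)
refitted = descend(fitted)

# Multi-start: run from several initial conditions and keep the lowest loss.
starts = [-2.0, -1.0, 0.5, 1.0, 2.0]
best = min((descend(x0) for x0 in starts), key=loss)

print(f"fitted={fitted:.4f} refitted={refitted:.4f} best={best:.4f}")
print(f"loss(fitted)={loss(fitted):.4f} loss(best)={loss(best):.4f}")
```

Re-running from `fitted` stops after a single negligible step, whereas the multi-start loop reaches the other, lower minimum.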

One option we could offer here is to let the user customize how the model is trained. Instead of calling fit_gpytorch_mll (here), we could allow the user to register custom routines for model fitting (I guess this is already possible by registering a dispatcher case in fit_gpytorch_mll directly). You could then write a custom model fitting helper that fits the model, checks its quality and processes it further as needed, before returning the trained mll. Would this address your needs?
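The idea of registering a custom fitting routine per model type can be sketched with Python's `functools.singledispatch`. This is only an analogy: it does not use BoTorch's dispatcher behind fit_gpytorch_mll, and `Model`, `CheckedModel`, `_optimize`, `fit_model`, `max_restarts` and `tol` are all hypothetical names. The custom routine fits, checks the quality of the result, and restarts from new initial conditions if needed.

```python
import random
from functools import singledispatch


class Model:
    """Hypothetical stand-in for a GP with a single hyper-parameter."""

    def __init__(self):
        self.lengthscale = None
        self.loss = float("inf")


class CheckedModel(Model):
    """Hypothetical model type that gets its own fitting routine."""


def _optimize(seed):
    # Placeholder "optimizer run": draws a lengthscale from a seeded RNG and
    # scores it with a toy loss whose minimum is at lengthscale = 1.0.
    lengthscale = random.Random(seed).uniform(0.1, 3.0)
    return lengthscale, (lengthscale - 1.0) ** 2


@singledispatch
def fit_model(model, seed=0):
    """Default routine: a single optimizer run."""
    model.lengthscale, model.loss = _optimize(seed)
    return model


@fit_model.register
def _(model: CheckedModel, seed=0, max_restarts=5, tol=0.05):
    """Custom routine: fit, check the loss, restart with new seeds if too high."""
    best_lengthscale, best_loss = None, float("inf")
    for attempt in range(max_restarts):
        lengthscale, loss = _optimize(seed + attempt)
        if loss < best_loss:
            best_lengthscale, best_loss = lengthscale, loss
        if best_loss <= tol:
            break
    model.lengthscale, model.loss = best_lengthscale, best_loss
    return model


default_fit = fit_model(Model())         # dispatches to the default routine
checked_fit = fit_model(CheckedModel())  # dispatches to the custom routine
print(default_fit.loss, checked_fit.loss)
```

Because the custom routine's first attempt uses the same seed as the default routine, its result can only match or improve on the default fit.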

RyanSaat commented 2 months ago

@saitcakmak Thank you for your response. It's very helpful! In my case, I already have a large dataset that I intend to use for pre-training a GP model. Conducting a new trial is time-consuming, and I don't want to waste any data. However, this seems to be incompatible with the suggested workflow.

saitcakmak commented 2 months ago

@RyanSaat Could you just attach the dataset you have to the Ax experiment as manual trials? This should work if the data comes from the same experiment. If the dataset originates from a different experiment, you could use a multi-task setup -- though this is not well documented at the moment.

I also don't fully understand what the intended workflow with pre-training looks like in this case. Do we carry over only the pre-trained hyper-parameters and not the data, or does the model carry over the pre-training data (which is then appended with the experiment data)? Do we further train the model after adding the experiment data, or do we want to keep the hyper-parameters the same as the pre-trained values? Pseudocode of the workflow would be helpful.

RyanSaat commented 2 months ago

@saitcakmak I am currently using numerical simulations to help me optimize the performance of a mechanism I designed. I have conducted numerous simulation runs, gathering performance metrics of the mechanism under various design parameters. However, due to the large design space and the lengthy duration of each simulation run, exhaustively exploring the design space through brute force is impractical. Thus, I am turning to Bayesian Optimization.

Having already accumulated a lot of preliminary data, I want to use it to better initialize the GP model and thereby reduce the optimization time. My envisioned workflow is as follows:

(1) Collect the simulation data I have obtained. This data includes the design parameters, X, which serve as inputs for the simulation, and the corresponding performance metrics of the mechanism, quantified by a scalar value.

(2) Using the existing data, configure the kernel of the GP model (e.g., length scale of each dimension) and pre-train the GP model.

(3) Employ the acquisition function to predict the next test point, then conduct another simulation.

(4) Use the results of this simulation to update the GP model.

(5) Repeat steps (3) and (4) until the desired number of iterations is reached.
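Step (2) mentions setting a length scale for each input dimension. As an illustration only (Ax and BoTorch choose and fit their own kernel, which may differ from this one), here is a squared-exponential kernel with one length scale per dimension, often called ARD, written in pure Python with the hypothetical names `ard_rbf`, `lengthscales` and `outputscale`. A shorter length scale makes the correlation decay faster along that dimension.

```python
import math


def ard_rbf(x, y, lengthscales, outputscale=1.0):
    """k(x, y) = outputscale * exp(-0.5 * sum_d ((x_d - y_d) / l_d)^2)."""
    if not (len(x) == len(y) == len(lengthscales)):
        raise ValueError("x, y and lengthscales must have the same length")
    sq = sum(((a - b) / l) ** 2 for a, b, l in zip(x, y, lengthscales))
    return outputscale * math.exp(-0.5 * sq)


x, y = (0.0, 0.0), (1.0, 0.0)  # points that differ only in dimension 0
print(ard_rbf(x, y, lengthscales=(1.0, 1.0)))  # exp(-0.5)
print(ard_rbf(x, y, lengthscales=(0.5, 1.0)))  # exp(-2.0): decays faster
```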

saitcakmak commented 2 months ago

What you want to do can be achieved by attaching the previous data you collected to the Ax experiment, without having to go around the modeling layer to add a custom pre-training step. If you attach the data to the experiment, it will be included in the training data for the GP, which then gets used to generate the candidates.

To do this, you can use AxClient.attach_trial (tutorial example) for each previous simulation and complete these trials (AxClient.complete_trial) with the observations you already collected. Once all previous data is attached (& completed), you can generate the next trial using AxClient.get_next_trial.
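A minimal sketch of that attach-and-complete loop, assuming the previous simulation runs are available as `(parameters, value)` pairs for a single objective. `attach_previous_runs`, `metric_name`, `previous_runs` and the sample values are hypothetical. The function only calls `attach_trial` and `complete_trial` on whatever client it receives, so it is exercised here with a small recording stub rather than a real AxClient (the stub is not part of Ax).

```python
def attach_previous_runs(ax_client, runs, metric_name="performance"):
    """Attach already-evaluated runs to an experiment and mark them complete.

    runs: iterable of (parameters, value) pairs, where parameters is a dict
    keyed by the experiment's parameter names and value is the observed metric.
    Returns the indices of the attached trials.
    """
    trial_indices = []
    for parameters, value in runs:
        # Register a trial at a user-chosen point (not one suggested by Ax).
        _, trial_index = ax_client.attach_trial(parameters=parameters)
        # Report the known observation so it becomes part of the GP's training data.
        ax_client.complete_trial(trial_index=trial_index, raw_data={metric_name: value})
        trial_indices.append(trial_index)
    return trial_indices


class RecordingClient:
    """Test double that records calls; it mimics only the two methods used above."""

    def __init__(self):
        self.attached, self.completed = [], {}

    def attach_trial(self, parameters):
        self.attached.append(parameters)
        return parameters, len(self.attached) - 1

    def complete_trial(self, trial_index, raw_data):
        self.completed[trial_index] = raw_data


# Hypothetical earlier simulation results.
previous_runs = [({"x1": 0.2, "x2": 1.5}, 3.1), ({"x1": 0.7, "x2": 0.4}, 2.4)]
client = RecordingClient()
print(attach_previous_runs(client, previous_runs))
```

With a real client, the function would be called after `ax_client.create_experiment(...)` and before the first `ax_client.get_next_trial()`, so that the next candidate is generated from a model trained on the earlier simulation data as well.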

RyanSaat commented 2 months ago

Aha, I see. Thank you!
