facebook / Ax

Adaptive Experimentation Platform
https://ax.dev

Support batch trial in service API #129

Closed dongyaoli closed 5 years ago

dongyaoli commented 5 years ago

First of all, I really appreciate the great work that has been done here and the fact that this library is open sourced.

In my use case, I would like to do Bayesian optimization of the hyperparameters of neural networks. Each training run can take more than 10 hours and is submitted to a GPU cluster through the slurm system. Because training takes such a long time, I want to run multiple trainings (arms) at the same time.

Right now, because the service API doesn't support batch trials, the optimization loop I set up with Ax is to create an Experiment to manage data and write everything else separately: initialize a batch trial, evaluate its arms through the slurm system, collect results, jointly optimize with botorch, record the trial in the experiment, and repeat.

If the service API is intended to be used for cases where trials are evaluated externally, I would like to request a feature making the service API support batch trials. I am more than happy to contribute to the implementation if possible. If so, I would appreciate any guidance on how the core development team would like this to be done.

lena-kashtelyan commented 5 years ago

@dongyaoli, thank you so much for the feedback! Just so that I better understand your use case: what makes the following solution unfitting for you? 1) Generate N (where N is the desired batch size) trials at once, 2) evaluate them, 3) complete them, and 4) repeat from step 1.

dongyaoli commented 5 years ago

> @dongyaoli, thank you so much for the feedback! Just so that I better understand your use case: what makes the following solution unfitting for you? 1) Generate N (where N is the desired batch size) trials at once, 2) evaluate them, 3) complete them, and 4) repeat from step 1.

@lena-kashtelyan maybe I didn't explain myself clearly. The flow you described is indeed the flow we want. Note that the evaluation happens externally, on other GPU machines requested through the slurm system. My current understanding is that the Experiment API supports batch trials but not external evaluation, and the service API (which is presumably designed for trials with external evaluation) does not support batch trials right now. Is my understanding of the library correct? Thanks!

lena-kashtelyan commented 5 years ago

I think it's me who wasn't fully clear, my apologies. In step 1), I meant just calling `ax_client.get_next_trial` N times to generate the N trials you would want to deploy at once.
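
For concreteness, here is a minimal sketch of that flow (an illustration only: `evaluate` and `num_batches` are hypothetical placeholders for your external slurm evaluation and outer loop):

```python
from ax.service.ax_client import AxClient

ax_client = AxClient()
ax_client.create_experiment(
    name="batch_sketch",
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-4, 1e-1], "log_scale": True},
    ],
    objective_name="accuracy",
)

N = 5  # desired batch size
for _ in range(num_batches):
    # 1) generate N trials up front; trials without data count as pending
    batch = [ax_client.get_next_trial() for _ in range(N)]
    # 2-3) evaluate externally, then complete each trial with its result
    for parameters, trial_index in batch:
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
```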

dongyaoli commented 5 years ago

> I think it's me who wasn't fully clear, my apologies. In step 1), I meant just calling `ax_client.get_next_trial` N times to generate the N trials you would want to deploy at once.

I think this is related to the optimization algorithm. In parallel Bayesian optimization, the next batch of arms to be evaluated (I guess in Ax terminology, arms is the right term here? multiple arms would form the next trial) is generated at the same time (for example, using `qExpectedImprovement` and `joint_optimize` in botorch). If I generate a single arm at a time, multiple times, those arms will be exactly the same. My understanding of the current AxClient implementation is that it only supports generating a new trial with 1 arm. If I do this multiple times with the same training data, it should give me the exact same arm and trial every time. Is my understanding wrong somewhere? Thank you for the prompt response!

lena-kashtelyan commented 5 years ago

@dongyaoli, the arms actually will not be the same: trials that have not yet been completed with data are considered pending, so the arms in them are not re-suggested. So just generating N trials should work for your case. An added bonus of doing it this way is that every new arm is chosen given all the freshest training data available, rather than only the training data that was available when the whole batch was generated. Also, you don't need to wait for the whole batch to complete before generating more trials.

However, make sure to use the master version instead of the stable version, because there was an important fix to the pending observations functionality in the Service API that has not yet made it into the stable version.

Balandat commented 5 years ago

@dongyaoli To weigh in on this: if you do what @lena-kashtelyan suggests, you effectively do sequential conditioning on the pending arms in later calls to `get_next_trial`. That is, the previously suggested arms will be passed to BoTorch under the hood, and the acquisition function will correctly account for them when generating the new arm by integrating over the model uncertainty at the suggested arms.

This is effectively a greedy approach to the full joint generation of q candidates in BoTorch's `qExpectedImprovement`. But there are submodularity results showing that you don't lose much by doing that (see this paper). In fact, in cases of high parallelism (i.e., if q is large), doing this may give better performance, simply because the optimization problems that need to be solved are much easier than the fully parallel problem.
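
Conceptually, the sequential conditioning amounts to attaching the pending points to the acquisition function; a minimal BoTorch sketch (assuming `model`, `best_value`, `sampler`, and `pending_x` are set up as in the usual qEI workflow):

```python
from botorch.acquisition import qExpectedImprovement

# The acquisition value at a new candidate is computed by integrating
# over the model's uncertainty at the pending (not-yet-evaluated) points.
acqf = qExpectedImprovement(
    model=model,
    best_f=best_value,
    sampler=sampler,
    X_pending=pending_x,  # previously suggested, still-running arms
)
```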

lena-kashtelyan commented 5 years ago

Just to stress this important point, @Balandat's response applies to the version of the Service API that is on master, not the stable version you would pull from pip.

dongyaoli commented 5 years ago

@Balandat @lena-kashtelyan Thank you for your detailed explanations! I have been reading through the theory and source code over the past few days. I do want to clarify a few things about how botorch handles the optimization, but I will ask those questions in the botorch repo, since I don't fully understand how Ax registers and uses the botorch model yet. Nevertheless, I rewrote the optimization loop using AxClient as you suggested. I am putting the code here (with some pseudocode) to check with you whether I am doing things correctly.

What I did before

```python
import torch
from ax import SearchSpace
from ax.modelbridge import get_sobol
from botorch import fit_gpytorch_model
from botorch.acquisition import qExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import joint_optimize
from botorch.sampling.samplers import SobolQMCNormalSampler
from gpytorch.mlls import ExactMarginalLogLikelihood

# assume parameters, bounds, init_trial_num, and num_trials are already
# created, and batch_evaluation_function is defined elsewhere
search_space = SearchSpace(parameters=parameters)
sobol = get_sobol(search_space)
# I attach the arms to an Experiment, but I am not showing that code
random_arm_list = sobol.gen(init_trial_num).arms

train_x = torch.FloatTensor(arms_to_array(random_arm_list))
train_y = batch_evaluation_function(train_x)
model = SingleTaskGP(train_X=train_x, train_Y=train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)
best_value = train_y.max()
sampler = SobolQMCNormalSampler(num_samples=500, seed=42, resample=False)

for i in range(num_trials):
    # jointly optimize qEI for a batch of q=5 candidates
    MC_EI = qExpectedImprovement(model, best_f=best_value, sampler=sampler)
    candidates = joint_optimize(
        acq_function=MC_EI,
        bounds=torch.FloatTensor(bounds),
        q=5,
        num_restarts=20,
        raw_samples=500,
        options={},
    )
    new_x = candidates.detach()
    # evaluate the new candidates and refit the model on the augmented data
    new_y = batch_evaluation_function(new_x)
    train_x = torch.cat([train_x, new_x])
    train_y = torch.cat([train_y, new_y])
    model = SingleTaskGP(train_X=train_x, train_Y=train_y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)
    best_value = train_y.max()
```

What I am doing now with ax_client

```python
import subprocess

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.unit_x import UnitX
from ax.service.ax_client import AxClient

def batch_evaluation_service(ax_client, batch_trial_num):
    process_list = []
    trial_index_list = []
    for i in range(batch_trial_num):
        parameters, trial_index = ax_client.get_next_trial()
        # proper files are generated from `parameters` (not shown)
        cmd = ("sbatch", file1, file2)
        process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL)
        process_list.append(process)
        trial_index_list.append(trial_index)

    for i in range(batch_trial_num):
        return_code = process_list[i].wait()
        if return_code == 0:
            score = collect_result(...)
            ax_client.complete_trial(trial_index=trial_index_list[i], raw_data=score)
        else:
            ax_client.log_trial_failure(trial_index=trial_index_list[i])

strategy = GenerationStrategy(
    name="GP+NEI",
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_arms=init_trial_num,
            model_kwargs={"scramble": True},
        ),
        GenerationStep(
            model=Models.GPEI,
            num_arms=-1,
            model_kwargs={"transforms": [UnitX, StandardizeY]},
        ),
    ],
)
ax = AxClient(generation_strategy=strategy)
ax.create_experiment(name="test", parameters=parameters)

for i in range(num_trials):
    batch_evaluation_service(ax, 5)
```

I am mostly unsure whether the strategy I use with ax_client is similar to the one I explicitly set up before, i.e., as Max mentioned above, whether it is a greedy version of the fully joint `joint_optimize`. Could you confirm, or point out if I did anything improper? I would also appreciate it if you could point out whether I should change any of the steps I listed above. Thank you!

One more follow-up: do you have plans to enable full joint optimization in AxClient, or do you consider the current approach good enough? With the current behavior, it seems to me that there is less and less difference between an arm and a trial: multiple arms within a trial (in the case of Experiments) are the same as a bunch of single-arm trials. Is my understanding correct?

lena-kashtelyan commented 5 years ago

Hello again, @dongyaoli!

1) Your usage of the Service API looks correct to me. One thing: you likely don't need to specify the generation strategy yourself, since Sobol+GPEI will be chosen under the hood anyway.

2) Will let @Balandat answer regarding the modeling here.

3) What do you mean by full joint optimization here? Just to clarify, so that I can fully respond to your question regarding whether we plan to include it in the Service API (AxClient).

4) There is, and always will be, a difference between an arm and a trial: an arm essentially specifies a parameterization, whereas a trial specifies a run of that parameterization. Therefore, two different trials can contain the same arm, for example. And data is logged for a trial, not an arm; if there is non-stationarity in the problem setup, two different trials with the same arm may have different metric results. Does that make sense?

A BatchTrial (multiple-arm trial) is also different from multiple single-arm trials: a BatchTrial is characterized by collective deployment; in other words, you deploy all five arms together and wait for all of them to complete before logging data for the batch. With single-arm trials, each trial is deployed independently of the others. So in your case, multiple single-arm trials are appropriate, because you run each trial in a separate process and log data for each separately.
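
To make the distinction concrete, here is a small sketch against the developer API (an illustration only; it assumes an `experiment` created elsewhere and a toy `lr` parameter):

```python
from ax import Arm

# The same arm (a parameterization) can appear in several trials (runs).
arm = Arm(parameters={"lr": 0.01})

# Two single-arm trials containing the same arm: each is deployed and
# completed independently, and each logs its own data.
trial_1 = experiment.new_trial().add_arm(arm)
trial_2 = experiment.new_trial().add_arm(arm)

# A BatchTrial: all arms are deployed together, and data is logged
# for the batch as a whole.
batch = experiment.new_batch_trial().add_arms_and_weights(
    arms=[Arm(parameters={"lr": 0.01}), Arm(parameters={"lr": 0.05})]
)
```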

Balandat commented 5 years ago

The default behavior of Ax is to use a greedy approach (which performs sequential conditioning on previously selected points) via BoTorch's `sequential_optimize` function. You can change that behavior by passing `model_gen_kwargs = {"optimizer_kwargs": {"joint_optimize": True}}` into the `GenerationStep`.
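
For example, the GPEI step from the strategy in your snippet would then look like this (same step, with the flag added):

```python
from ax.modelbridge.generation_strategy import GenerationStep
from ax.modelbridge.registry import Models

# BO step with fully joint generation of the q candidates
gpei_step = GenerationStep(
    model=Models.GPEI,
    num_arms=-1,
    model_gen_kwargs={"optimizer_kwargs": {"joint_optimize": True}},
)
```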

dongyaoli commented 5 years ago

Thank you @lena-kashtelyan and @Balandat again for your help! I think I have enough understanding for now and am going to close this.