facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License
2.37k stars 308 forks source link

Plotting functions do not work with HierarchicalSearchSpace #2481

Status: Closed — SBlokhuizen closed this issue 3 months ago

SBlokhuizen commented 4 months ago

I am experimenting with a hierarchical search space, and I noticed that the plotting functions, such as contour and slice plots, do not work. Below is a minimal reproducible example of the code that triggers the error. By simply replacing the HierarchicalSearchSpace with a regular SearchSpace, the plots render without issue.

Is plotting not supported for hierarchical search spaces? Thanks!

import random

import pandas as pd
from ax import (
    ChoiceParameter,
    Data,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    SearchSpace,
)
from ax.core.runner import Runner
from ax.core.search_space import HierarchicalSearchSpace
from ax.modelbridge.registry import Models
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.slice import interact_slice, plot_slice
from ax.utils.common.result import Ok
from ax.utils.notebook.plotting import render

class MyMetric(Metric):
    """Toy metric that reports a uniformly random value for every arm.

    Only used to drive this reproduction; the values carry no signal.
    """

    def __init__(self, name):
        super().__init__(name=name)

    def fetch_trial_data(self, trial, **kwargs):
        """Return one random, noiseless observation per arm of ``trial``.

        Args:
            trial: The trial whose arms are "evaluated".
            **kwargs: Accepted and ignored so this override matches the
                ``Metric.fetch_trial_data(trial, **kwargs)`` base signature.

        Returns:
            ``Ok`` wrapping a ``Data`` object with one row per arm, holding
            a ``random.random()`` mean and a zero standard error.
        """
        records = [
            {
                "arm_name": arm_name,
                "metric_name": self.name,
                "trial_index": trial.index,
                "mean": random.random(),
                "sem": 0.0,
            }
            for arm_name in trial.arms_by_name
        ]
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))

# Root choice "B" decides which pair of float range parameters is active.
_B_DEPENDENTS = {
    "B1": ["B1a", "B1b"],
    "B2": ["B2a", "B2b"],
}
# (name, lower, upper) for each dependent float range parameter, in order.
_DEPENDENT_BOUNDS = [
    ("B1a", 0, 5),
    ("B1b", 1, 6),
    ("B2a", 1, 7),
    ("B2b", 1, 8),
]

search_space = HierarchicalSearchSpace(
    parameters=[
        ChoiceParameter(
            name="B",
            parameter_type=ParameterType.STRING,
            values=["B1", "B2"],
            is_ordered=False,
            sort_values=False,
            dependents=_B_DEPENDENTS,
        ),
        *(
            RangeParameter(
                name=param_name,
                parameter_type=ParameterType.FLOAT,
                lower=lower,
                upper=upper,
            )
            for param_name, lower, upper in _DEPENDENT_BOUNDS
        ),
    ]
)

class MyRunner(Runner):
    """Runner stub: it deploys nothing and only tags each trial with its index."""

    def run(self, trial):
        """Return run metadata that names the trial by its index, as a string."""
        run_metadata = {"name": str(trial.index)}
        return run_metadata

metric_name = "my_metric"
optimization_config = OptimizationConfig(
    objective=Objective(metric=MyMetric(metric_name), minimize=True),
)

exp = Experiment(
    name="my_exp",
    search_space=search_space,
    optimization_config=optimization_config,
    runner=MyRunner(),
)

NUM_SOBOL_TRIALS = 5
NUM_BOTORCH_TRIALS = 10
TOTAL_TRIALS = NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS


def _run_and_complete(gen_run):
    """Attach ``gen_run`` to a new trial on ``exp``, run it, log its arm, and complete it."""
    new_trial = exp.new_trial(generator_run=gen_run)
    new_trial.run()
    print(gen_run.arms[0].parameters)
    new_trial.mark_completed()


# Quasi-random initialization phase.
sobol = Models.SOBOL(
    search_space=exp.search_space,
    optimization_config=optimization_config,
)
for trial_number in range(1, NUM_SOBOL_TRIALS + 1):
    print(f"Running sobol trial {trial_number}/{NUM_SOBOL_TRIALS}...")
    _run_and_complete(sobol.gen(n=1))

# Bayesian-optimization phase: refit on all data collected so far before
# generating each new candidate.
for trial_number in range(NUM_SOBOL_TRIALS + 1, TOTAL_TRIALS + 1):
    print(f"Running BO trial {trial_number}/{TOTAL_TRIALS}...")
    bo_model = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())
    _run_and_complete(bo_model.gen(n=1))

model = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())

# Reported in the issue: none of these render with a HierarchicalSearchSpace,
# while they do with a plain SearchSpace.
render(interact_contour(model, metric_name))
render(plot_slice(model, param_name="B1a", metric_name=metric_name))
render(interact_slice(model))
render(plot_contour(model, param_x="B1a", param_y="B1b", metric_name=metric_name))
mpolson64 commented 3 months ago

@SBlokhuizen thanks for bringing this to our attention. We are currently in the process of a larger reworking of our plotting functionality which we hope will address issues like this and many others.

For now, I will close out this task and add a link to it on our "wishlist" master task #566. We hope to have something to share with you soon.