pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] The optimizer is unable to optimize MOMF due to (numerically) zero acquisition values #2121

Closed vlad451101 closed 8 months ago

vlad451101 commented 10 months ago

🐛 Bug

I have an optimization problem that consists of 4 input and 5 output parameters. It takes approximately 3 hours to evaluate one sample of the true function.

I have successfully implemented this problem in the multi-objective Bayesian optimization tutorial (https://botorch.org/tutorials/multi_objective_bo), where I had no problems whatsoever: everything worked just fine and, most importantly, optimizing the acquisition function took a few seconds, or a few minutes at most.

I then wanted to try multi-objective multi-fidelity optimization (https://botorch.org/tutorials/Multi_objective_multi_fidelity_BO) to possibly speed up the Bayesian optimization of my problem. However, in my case I am not able to get any new candidates, and the optimization of the MOMF acquisition function never completes. When I ran the optimization, I didn't get new candidates even after a few hours; it basically runs indefinitely with no sign of stopping.

I tried changing the number of initial training samples and various settings of the MOMF acquisition function and the optimize_acqf optimizer (RAW_SAMPLES, MC_SAMPLES, ...), but I didn't find anything that resolved the problem.

I would greatly appreciate any advice or help in resolving this issue!

To reproduce

Below you will find code that should replicate the problem I described. For this test case, I have generated and evaluated some initial training data, attached in the following text file (training_data.txt, https://github.com/pytorch/botorch/files/13403434/training_data.txt). The data in the text file represents the raw inputs and outputs that are sent to and evaluated by the true function.

## Libraries
########################################################################################################################
import warnings
import torch
import gpytorch
from botorch.exceptions import BadInitialCandidatesWarning, InputDataWarning
from botorch.utils.transforms import normalize, unnormalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls.sum_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_mll
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.multi_objective.box_decompositions.non_dominated import FastNondominatedPartitioning
from botorch.acquisition.multi_objective.multi_fidelity import MOMF
from botorch.optim.optimize import optimize_acqf
import pandas as pd
from copy import deepcopy
warnings.filterwarnings("ignore", category=BadInitialCandidatesWarning)
warnings.filterwarnings("ignore", category=InputDataWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
########################################################################################################################

## All functions
########################################################################################################################
def cost_func(x: torch.Tensor, tkwargs: dict):
    """A simple exponential cost function."""
    exp_arg = torch.tensor(2, **tkwargs)
    val = torch.exp(exp_arg * x)
    return val

def cost_callable(X: torch.Tensor) -> torch.Tensor:
    r"""Wrapper for the cost function that takes care of shaping
    input and output arrays for interfacing with cost_func.
    This is passed as a callable function to MOMF.

    Args:
        X: A `batch_shape x q x d`-dim Tensor
    Returns:
        Cost `batch_shape x q x m`-dim Tensor of cost generated
        from fidelity dimension using cost_func.
    """
    tkwargs = {  # tkwargs is a dictionary containing the data type and device
        "dtype": torch.double,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    }
    cost = cost_func(torch.flatten(X), tkwargs).reshape(X.shape)
    cost = cost[..., [-1]]
    return cost

def getFidelityComponent(train_x: torch.Tensor, n_points: int, tkwargs: dict, fid_samples_n_points=1001) -> torch.Tensor:
    """
    Generates training data with Fidelity dimension sampled
    from a probability distribution that depends on Cost function
    """
    # Array from which fidelity values are sampled
    fid_samples = torch.linspace(0, 1, fid_samples_n_points, **tkwargs)
    # Probability calculated from the Cost function
    prob = 1 / cost_func(fid_samples, tkwargs)
    # Normalizing
    prob = prob / torch.sum(prob)
    # Generating indices to choose fidelity samples
    idx = prob.multinomial(num_samples=n_points, replacement=True)
    train_x[:, -1] = fid_samples[idx]
    # Return train_x with the sampled fidelity values in the last column
    return train_x

def get_fidelity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Wrapper around the Objective function to take care of fid_obj stacking"""
    fid = 1 * x[..., -1]  # Getting the fidelity objective values
    fid_out = fid.unsqueeze(-1)
    # Concatenating objective values with fid_objective
    y_out = torch.cat([y, fid_out], -1)
    return y_out

def changeObjective(tensor, objectives):
    # Copy tensor
    newTensor = deepcopy(tensor)
    # Select columns
    for col,objective in enumerate(objectives):
        # Change absolute value if objective is minimized
        if objective == 'Minimize':
            # Extract the selected column
            column = newTensor[:, col]
            # Negate the values in the column
            negated_column = -column
            # Update the tensor with the negated column
            newTensor[:, col] = negated_column

    return newTensor

def initialize_model(train_x, train_obj):
    model = SingleTaskGP(train_x, train_obj,
                         covar_module=gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=len(train_x[0]))),
                         likelihood=gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive()),
                         outcome_transform=Standardize(m=train_obj.shape[-1]))

    # Create the marginal log likelihood
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    return mll, model

def optimize_MOMF_and_get_obs(model: SingleTaskGP, train_obj: torch.Tensor, sampler: SobolQMCNormalSampler,
                              ref_points: torch.Tensor, standard_bounds: torch.Tensor, unnormalize_bounds: list,
                              BATCH_SIZE: int, tkwargs: dict, NUM_RESTARTS: int, RAW_SAMPLES: int): # cost_call: Callable[[torch.Tensor], torch.Tensor],
    """
    Wrapper to call MOMF and optimizes it in a sequential greedy
    fashion returning a new candidate and evaluation
    """
    partitioning = FastNondominatedPartitioning(ref_point=torch.tensor(ref_points, **tkwargs), Y=train_obj)

    acq_func = MOMF(
        model=model,
        ref_point=ref_points,  # use known reference point
        partitioning=partitioning,
        sampler=sampler,
        cost_call=cost_callable,
        sample_around_best=True,
    )
    # Optimization
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=standard_bounds,
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,  # used for initialization heuristic
        options={"batch_limit": 5, "maxiter": 200, "nonnegative": True},
        sequential=True,
    )
    # Unnormalize the new candidates back to the original bounds
    newCandidates = unnormalize(candidates.detach(), bounds=torch.tensor(unnormalize_bounds, **tkwargs))
    return newCandidates
########################################################################################################################

## Inputs and others
########################################################################################################################
# Path to the training data file
pathResults = r'training_data.txt'

# Selected input variables for Bayesian optimization loop
inputs = ['input1', 'input2', 'input3', 'input4']

# Selected output variables for Bayesian optimization loop
outputs = ['output1', 'output2', 'output3', 'output4', 'output5']

# Dimensions for MO-only optimization
dim_xMO = len(inputs)  # Input dimension for MO-only optimization
dim_yMO = len(outputs)  # Output dimension for MO-only optimization
# Dimensions for MOMF optimization (extra fidelity input and fidelity objective)
dim_x = dim_xMO + 1  # Input dimension for MOMF optimization
dim_y = dim_yMO + 1  # Output dimension for MOMF optimization

# Output optimization objectives
objectives = ['Maximize', 'Maximize', 'Maximize', 'Minimize', 'Maximize']

# Kwargs for torch values
tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

# Selected bounds for variables
boundsMO = torch.tensor([[120, 9, 0.6, 23.333333333],
                            [240, 21, 2.7, 46.666666667]], **tkwargs) # For MO only
bounds = torch.tensor([[120, 9, 0.6, 23.333333333, 0],
                            [240, 21, 2.7, 46.666666667, 1]], **tkwargs) # For MOMF

# Bounds for MOMF optimization
standard_bounds = torch.tensor([[0.0] * dim_x, [1.0] * dim_x], **tkwargs)

# Selected reference points for objectives
ref_points = [4.2, 68, 0.43, 0, 0.7, 0] # For MOMF Hypervolume calculation

# Bayesian optimization loop variables
INIT_SIZE = 100 # Number of randomly generated parallel calculations for initial training data
BATCH_SIZE = 19 # Number of parallel calculations for Bayesian optimization loop
NUM_RESTARTS = 10 # Number of restarts of acquisition optimizer
RAW_SAMPLES = 256 # Number of samples for acquisition optimizer
MC_SAMPLES = 128 # Number of samples for the Monte Carlo sampler
########################################################################################################################

## Optimization loop
########################################################################################################################
# Get all data
data = pd.read_csv(pathResults, sep=' ', header=0)
# Generate random fidelity data
train_fid = torch.rand(INIT_SIZE, 1, **tkwargs)
# Get together input data and fidelity component
train_x = torch.cat((normalize(torch.tensor(data[inputs].values, **tkwargs), boundsMO), train_fid), dim=1)
# Get the fidelity component for initial input training data
train_x = getFidelityComponent(train_x, INIT_SIZE, tkwargs)
# Calls the objective wrapper to load train_obj and append the fidelity objective - normally here I would evaluate the true function
train_obj_no_fid = torch.tensor(data[outputs].values, **tkwargs)
train_obj = get_fidelity(train_x, train_obj_no_fid)
# Change absolute value of training data - for optimization purposes
train_obj_C = changeObjective(train_obj, objectives)

# Initialize the GP model on the initial data
mll, model = initialize_model(train_x, train_obj_C)

# Fit the model
fit_gpytorch_mll(mll)

# Get objective values for optimizing the acquisition function (padded with BATCH_SIZE rows of zeros)
train_opt = torch.cat([train_obj_C, torch.zeros(BATCH_SIZE, dim_y, **tkwargs)])

# Generate Sampler
momf_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([MC_SAMPLES]))

# optimize acquisition functions and get new observations
new_x = optimize_MOMF_and_get_obs(
    model=model,
    train_obj=train_opt,
    sampler=momf_sampler,
    ref_points=ref_points,
    standard_bounds=standard_bounds,
    unnormalize_bounds=bounds,
    BATCH_SIZE=BATCH_SIZE,
    tkwargs=tkwargs,
    NUM_RESTARTS=NUM_RESTARTS,
    RAW_SAMPLES=RAW_SAMPLES
)
########################################################################################################################

Expected Behavior

The optimizer optimize_acqf should be able to optimize the MOMF acquisition function and return new candidates within a reasonable amount of time.

System information

  • BoTorch Version - 0.9.4
  • GPyTorch Version - 1.11
  • PyTorch Version - 2.0.1+cpu
  • Computer OS - Windows 10

Additional context

eytan commented 10 months ago

You may wish to try out the Hypervolume Knowledge Gradient (HVKG) instead. Here is a tutorial on using it for multi-fidelity MOBO:

https://github.com/pytorch/botorch/pull/2101
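For reference, a rough sketch of setting up MF-HVKG on the repro above. This is a sketch only: the class and helper names exist in current BoTorch, but the exact constructor arguments here are my assumption and should be checked against the tutorial in that PR. Note also that HVKG works in the objective space only, so the model would be fit on the 5 real objectives without MOMF's extra fidelity objective.

# Sketch only -- argument names are assumptions; see the linked HVKG tutorial.
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qMultiFidelityHypervolumeKnowledgeGradient,
)
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.models.cost import AffineFidelityCostModel

# The fidelity parameter is the last input column; the Pareto front is
# evaluated at the target fidelity of 1.0.
target_fidelities = {dim_x - 1: 1.0}
cost_model = AffineFidelityCostModel(fidelity_weights={dim_x - 1: 1.0}, fixed_cost=1.0)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

acq_func = qMultiFidelityHypervolumeKnowledgeGradient(
    model=model,  # GP fit on the 5 objectives only (no fidelity objective)
    ref_point=torch.tensor(ref_points[:dim_yMO], **tkwargs),
    num_fantasies=8,
    num_pareto=10,
    cost_aware_utility=cost_aware_utility,
    # Depending on the BoTorch version this may instead be a target_fidelities dict;
    # check the tutorial for the exact argument.
    project=lambda X: project_to_target_fidelity(X=X, target_fidelities=target_fidelities),
)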


vlad451101 commented 10 months ago

Thank you very much for your suggestion. This solution works like a charm and I really appreciate your help.

For the test case example I gave above, optimizing the acquisition function now takes only a few seconds. I have yet to try the full Bayesian optimization with multiple loops, but I don't think that will be a problem.

I just have a small suggestion for improving the tutorial you pointed me to. When evaluating the Pareto front at the highest fidelity using NSGA-II (the get_pareto function), the following imports from the pymoo library do not work:

from pymoo.algorithms.nsga2 import NSGA2
from pymoo.model.problem import Problem
from pymoo.util.termination.max_gen import MaximumGenerationTermination

I don't know which version you are using, but for pymoo version 0.6.1 the following imports are needed instead:

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import Problem
from pymoo.termination.max_gen import MaximumGenerationTermination
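
For context, a minimal sketch of how those corrected imports fit together under pymoo 0.6 (the toy problem here is purely illustrative and is not the tutorial's actual get_pareto implementation):

import numpy as np
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import Problem
from pymoo.optimize import minimize
from pymoo.termination.max_gen import MaximumGenerationTermination

class ToyProblem(Problem):
    """Illustrative 2-objective problem standing in for the GP posterior mean."""
    def __init__(self):
        super().__init__(n_var=4, n_obj=2, xl=np.zeros(4), xu=np.ones(4))

    def _evaluate(self, x, out, *args, **kwargs):
        # pymoo minimizes, so negate anything you want to maximize.
        out["F"] = np.stack([x.sum(axis=-1), (1 - x).sum(axis=-1)], axis=-1)

res = minimize(
    ToyProblem(),
    NSGA2(pop_size=50),
    termination=MaximumGenerationTermination(n_max_gen=100),
    verbose=False,
)
pareto_X, pareto_F = res.X, res.F
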
Balandat commented 10 months ago

Thanks - yes pymoo 0.6+ brought with it some breaking changes; we can update the tutorial to be compatible with those.

sdaulton commented 10 months ago

Glad MF-HVKG is working well for you. Regarding MOMF being slow here: your problem has 4 objectives + 1 for MOMF's fidelity objective. Since hypervolume has super-polynomial complexity with respect to the number of objectives, adding MOMF's fifth objective makes computing and optimizing the acquisition function (qEHVI) over 5 objectives very expensive. MF-HVKG operates in the 4-objective space.

Another reason this would be slow is that you are using qEHVI (which MOMF uses) with a batch size of 19. qEHVI has exponential complexity with respect to the batch size, so using a batch size of 19 with 5 objectives would be infeasibly slow. MF-HVKG supports batch BO using fantasization and only computes hypervolumes of a small finite Pareto set (not of the batch of candidates), so it will be much faster.

To speed up MOMF, one would want to implement a version of MOMF that uses qNEHVI, which is the same acquisition function as qEHVI but incurs only polynomial complexity with respect to the batch size (https://proceedings.neurips.cc/paper/2021/hash/11704817e347269b7254e744b5e22dac-Abstract.html).
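For illustration, constructing qNEHVI directly on the repro's fitted model would look roughly like the sketch below. Names such as model, ref_points, train_x, and momf_sampler come from the repro above; note this is plain qNEHVI with no cost-aware fidelity weighting, not the qNEHVI-based MOMF variant described here.

# Sketch: qNEHVI scales polynomially in the batch size, unlike qEHVI.
from botorch.acquisition.multi_objective.monte_carlo import (
    qNoisyExpectedHypervolumeImprovement,
)

acq_func = qNoisyExpectedHypervolumeImprovement(
    model=model,
    ref_point=ref_points,   # same reference point as the repro
    X_baseline=train_x,     # previously evaluated (normalized) inputs
    sampler=momf_sampler,
    prune_baseline=True,    # drop baseline points unlikely to be Pareto-optimal
)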

esantorella commented 10 months ago

Thanks for the bug report and the repro! MOMF isn't converging because the hypervolume gain is (numerically) zero. MOMF is giving zero acquisition values and gradients at all random points I checked, and the optimizer makes no progress at all. This is a numerical issue; by switching MOMF to use qLogExpectedHypervolumeImprovement instead of qExpectedHypervolumeImprovement, I can see that hv_gain is positive but tiny. With that substitution, acquisition values are nonzero and the optimizer does make progress. However, it's still quite slow and I didn't have the patience to wait for optimize_acqf to finish.
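For anyone wanting to reproduce this check, a hypothetical diagnostic along these lines (using acq_func as constructed in the repro, with q=1 so a single qEHVI evaluation stays cheap) makes the flat acquisition surface visible:

# Evaluate the acquisition at random normalized points and inspect values/gradients.
X = torch.rand(16, 1, dim_x, requires_grad=True, **tkwargs)
vals = acq_func(X)
vals.sum().backward()
print(vals.detach().max())   # ~0.0 -> numerically zero acquisition values
print(X.grad.abs().max())    # ~0.0 -> flat gradients, so the optimizer makes no progress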

Are there any blockers to cutting MOMF over to qLogExpectedHypervolumeImprovement @SebastianAment ? I can put in a PR for this.

vlad451101 commented 10 months ago

@sdaulton Thank you very much for the clarification.

The main reason I use a larger batch size is that it is very convenient in my case and I want to get as much data as possible in one loop. Ideally I would use an even larger batch size, but I am limited by my PC setup, which cannot evaluate the true function for more than 19 samples in parallel.

SebastianAment commented 10 months ago

Are there any blockers to cutting MOMF over to qLogExpectedHypervolumeImprovement @SebastianAment ? I can put in a PR for this.

That's a great idea; two thoughts to this end: 1) we could also add the option to use qLogNEHVI, which would remedy both the zero acquisition values and, as @sdaulton pointed out, the exponential scaling, and 2) this would also require changing the cost-weighting to work in log-space, but that should be reasonably straightforward.
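To illustrate point 2 (this is just the algebra, not BoTorch's actual MOMF internals): if the hypervolume gain is kept in log-space, the inverse-cost weighting turns from a division into a subtraction of log-costs, which no longer underflows when the gain is tiny.

# Schematic only; log_hv_gain and cost are hypothetical intermediate tensors.
# Linear space (roughly):  acq = hv_gain / cost
# Log space:               log_acq = log_hv_gain - log(cost)
log_acq = log_hv_gain - torch.log(cost)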

eytan commented 10 months ago

This could be done but is there any reason to prefer MOMF over HVKG?


sdaulton commented 10 months ago

This could be done but is there any reason to prefer MOMF over HVKG?

Only if candidate generation speed is a bottleneck

SebastianAment commented 8 months ago

Hi @vlad451101, were you able to resolve the issue with HVKG or by switching out the acquisition function in MOMF to qLog(N)EHVI?

vlad451101 commented 8 months ago

Hi @SebastianAment, yes, I have been able to resolve the issue. Both HVKG and qLogNEHVI work perfectly.

Therefore, I am closing this issue. If someone finds anything new or related to this issue, feel free to reopen it.