secondmind-labs / trieste

A Bayesian optimization toolbox built on TensorFlow
Apache License 2.0

Local models and datasets #788

Closed: khurram-ghani closed this PR 7 months ago

khurram-ghani commented 9 months ago

Related issue(s)/PRs: #782

Summary

This PR adds support for local models and datasets. The following scenarios are supported:

The initial dataset can be a global dataset sampled from the whole search space. On the first iteration, this data is replicated to each region; from then on, each region has its own local dataset. For the batch trust-region (batch-TR) algorithm, each region's dataset is filtered after every iteration so that it contains only the points inside that region. TREGO does not apply this filtering.
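To make the data flow concrete, here is a minimal NumPy sketch of the two steps described above: replicating a global initial dataset to every region, then filtering each region's dataset to the points inside its box, as batch-TR does after each iteration. The helpers replicate_dataset and filter_to_region, and the representation of a region as (lower, upper) bounds, are hypothetical and are not trieste's internal API.

```python
import numpy as np


def replicate_dataset(query_points, observations, num_regions):
    """Give every region its own copy of the global initial data (hypothetical helper)."""
    return [(query_points.copy(), observations.copy()) for _ in range(num_regions)]


def filter_to_region(query_points, observations, lower, upper):
    """Keep only the points inside the box [lower, upper] (hypothetical helper)."""
    inside = np.all((query_points >= lower) & (query_points <= upper), axis=-1)
    return query_points[inside], observations[inside]


# Global initial data sampled from the whole 2D search space [0, 1]^2.
rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(20, 2))
y0 = np.sum(x0**2, axis=-1, keepdims=True)

# Two hypothetical trust regions, each an axis-aligned box.
regions = [
    (np.array([0.0, 0.0]), np.array([0.5, 0.5])),
    (np.array([0.4, 0.4]), np.array([1.0, 1.0])),
]

# First iteration: every region starts from the same global data.
local_datasets = replicate_dataset(x0, y0, len(regions))

# After an iteration, batch-TR keeps only each region's own points.
# (TREGO would skip this filtering step.)
local_datasets = [
    filter_to_region(x, y, lower, upper)
    for (x, y), (lower, upper) in zip(local_datasets, regions)
]

for i, (x, _) in enumerate(local_datasets):
    print(f"region {i}: {len(x)} points")
```

Keeping each local dataset as a separate (query_points, observations) pair mirrors the idea that every region owns its data once the first iteration is over.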

Note: replicating the initial data can cause a problem when a global model is used, because the same points may then appear more than once in the combined data. This only happens if regions overlap and more than one region contains the same initial data point; otherwise, filtering removes the duplicates. The main way to avoid the issue is to provide local initial datasets instead of a single global initial dataset.
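The duplication issue can be reproduced with a toy one-dimensional example. This sketch assumes two hypothetical overlapping boxes and a global model trained on the concatenation of the local datasets; it only illustrates the effect and is not how trieste assembles data internally.

```python
import numpy as np

# Two overlapping regions (hypothetical boxes) in the 1D search space [0, 1].
regions = [(np.array([0.0]), np.array([0.6])), (np.array([0.4]), np.array([1.0]))]

# A global initial dataset; the point 0.5 lies in the overlap of both regions.
x_global = np.array([[0.1], [0.5], [0.9]])


def points_in(x, lower, upper):
    """Return the rows of x that lie inside the box [lower, upper]."""
    inside = np.all((x >= lower) & (x <= upper), axis=-1)
    return x[inside]


# Replicate the global data to each region, then filter to each region.
local = [points_in(x_global, lower, upper) for lower, upper in regions]

# A global model trained on the union of local datasets sees 0.5 twice.
combined = np.concatenate(local, axis=0)
num_duplicates = len(combined) - len(np.unique(combined, axis=0))
print(combined.ravel(), "duplicates:", num_duplicates)

# The workaround suggested above: sample separate initial data per region.
rng = np.random.default_rng(1)
local_init = [rng.uniform(lower, upper, size=(2, 1)) for lower, upper in regions]
local_combined = np.concatenate(local_init, axis=0)
local_duplicates = len(local_combined) - len(np.unique(local_combined, axis=0))
print("duplicates with local initial data:", local_duplicates)
```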

The trust_region notebook contains a new, temporary TEST section that shows how local models can be used. Note in the GIF that, at each iteration, the query points are filtered to lie only inside each region's box. This section is for demonstration only and will be removed before this PR is merged. A follow-on PR will update the TURBO section to use local models and demonstrate this functionality.

Fully backwards compatible: no

The BatchTrustRegion acquisition rule returns rank-3 query points, rather than rank-2 points as other rules (including the previous trust-region rules) do. Users should therefore use the new batched observer with this rule. BayesianOptimizer already handles this, but with AskTellOptimizer users must apply the batched observer themselves, as follows:

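```python
from trieste.objectives.utils import mk_batch_observer

observer = ...  # your existing observer for rank-2 query points
# Wrap it so it accepts the rank-3 query points returned by BatchTrustRegion.
batch_observer = mk_batch_observer(observer)

# ask_tell is an AskTellOptimizer configured with a BatchTrustRegion rule.
new_points = ask_tell.ask()
new_data = batch_observer(new_points)
ask_tell.tell(new_data)
```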
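Conceptually, a batch observer lets an observer written for rank-2 points handle the rank-3 points returned by BatchTrustRegion. The following NumPy-only toy illustrates the idea under the assumption that query points have shape [N, B, D], with B indexing the regions. The functions toy_objective and toy_batch_observer are hypothetical, and trieste's mk_batch_observer may differ in axis ordering and in how it keys the returned datasets.

```python
import numpy as np


def toy_objective(x):
    """Toy objective: sum of squares, returning shape [N, 1]."""
    return np.sum(x**2, axis=-1, keepdims=True)


def toy_batch_observer(query_points):
    """Observe rank-3 query points region by region (hypothetical sketch).

    Assumes query_points has shape [N, B, D], where B indexes the regions.
    Returns a dict mapping each region index to (query_points, observations).
    """
    if query_points.ndim != 3:
        raise ValueError(f"expected rank-3 query points, got rank {query_points.ndim}")
    datasets = {}
    for b in range(query_points.shape[1]):
        x_b = query_points[:, b, :]  # rank-2 slice for region b: [N, D]
        datasets[b] = (x_b, toy_objective(x_b))
    return datasets


# Two query points per region, three regions, two input dimensions.
points = np.arange(12, dtype=float).reshape(2, 3, 2)
data = toy_batch_observer(points)
for b, (x, y) in data.items():
    print(b, x.shape, y.shape)
```

Splitting along the region axis means the underlying objective only ever sees ordinary rank-2 batches, and each region receives its own new data to add to its local dataset.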

PR checklist