brandontrabucco / design-bench

Benchmarks for Model-Based Optimization
MIT License

Create New Task and Run Baseline Models #4

Closed: bonaventuredossou closed this issue 2 years ago

bonaventuredossou commented 2 years ago

How can I create a new task (from a new local dataset) to run the baseline models reported in the paper?

brandontrabucco commented 2 years ago

Hi bonaventure,

To create a new MBO task from a continuous-valued dataset, see this snippet, taken from the README:

from design_bench.datasets.continuous_dataset import ContinuousDataset
import design_bench
import numpy as np

# define a custom dataset subclass of ContinuousDataset
class QuadraticDataset(ContinuousDataset):

    def __init__(self, **kwargs):

        # define a set of inputs and outputs of a quadratic function
        x = np.random.normal(0.0, 1.0, (5000, 7))
        # sum over the feature axis so y has one label per row: shape (5000, 1)
        y = (x ** 2).sum(axis=1, keepdims=True)

        # pass inputs and outputs to the base class
        super(QuadraticDataset, self).__init__(x, y, **kwargs)

# parameters used for building the validation set
split_kwargs=dict(
    val_fraction=0.1,
    subset=None,
    shard_size=5000,
    to_disk=True,
    disk_target="quadratic/split",
    is_absolute=True)

# parameters used for building the model
model_kwargs=dict(
    hidden_size=512,
    activation='relu',
    num_layers=2,
    epochs=5,
    shuffle_buffer=5000,
    learning_rate=0.001)

# keyword arguments for building the dataset
dataset_kwargs=dict(
    max_samples=None,
    distribution=None,
    max_percentile=80,
    min_percentile=0)

# keyword arguments for training FullyConnected oracle
oracle_kwargs=dict(
    noise_std=0.0,
    max_samples=None,
    distribution=None,
    max_percentile=100,
    min_percentile=0,
    split_kwargs=split_kwargs,
    model_kwargs=model_kwargs)

# register the new dataset with design_bench
design_bench.register(
    'Quadratic-FullyConnected-v0', QuadraticDataset,
    'design_bench.oracles.tensorflow:FullyConnectedOracle',
    dataset_kwargs=dataset_kwargs, oracle_kwargs=oracle_kwargs)

# build the new task (and train a model)
task = design_bench.make("Quadratic-FullyConnected-v0")

def solve_optimization_problem(x0, y0):
    # placeholder: substitute a model-based optimization algorithm here;
    # this stub simply returns the dataset inputs unchanged
    return x0

# evaluate the performance of the solution x_star
x_star = solve_optimization_problem(task.x, task.y)
y_star = task.predict(x_star)
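The labels passed to ContinuousDataset should hold one score per design, i.e. shape (n, 1) for n inputs. In NumPy, calling `.sum(keepdims=True)` without an `axis` collapses every axis into a single (1, 1) total, while `axis=1` sums only over the 7 input features. A quick standalone check in plain NumPy (no design_bench needed):

```python
import numpy as np

# same synthetic inputs as in the QuadraticDataset example
x = np.random.normal(0.0, 1.0, (5000, 7))

# no axis: every element is summed, leaving a single (1, 1) value
y_collapsed = (x ** 2).sum(keepdims=True)

# axis=1: sum over the 7 features, giving one score per row
y = (x ** 2).sum(axis=1, keepdims=True)

print(y_collapsed.shape)  # (1, 1)
print(y.shape)            # (5000, 1)
```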

If you also want to implement an exact oracle function, this example may help (the example above uses an approximate neural-network oracle instead):

https://github.com/brandontrabucco/design-bench/blob/new-api/design_bench/oracles/exact/hopper_controller_oracle.py
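For the quadratic toy problem above, the ground-truth function is known in closed form, so an exact oracle can score any design, including designs that never appear in the dataset. The plain function below is only a sketch of that idea: `quadratic_oracle` is a hypothetical name, and wiring an exact oracle into design_bench would follow the linked example rather than this code.

```python
import numpy as np

def quadratic_oracle(x):
    """Hypothetical exact oracle: ground-truth scores for a batch of designs.

    x has shape (n, 7); the result has shape (n, 1), matching the labels
    used by the QuadraticDataset example.
    """
    return (x ** 2).sum(axis=1, keepdims=True)

# designs from the data distribution, and designs far outside it
x_data = np.random.normal(0.0, 1.0, (4, 7))
x_new = 10.0 * x_data

# the exact oracle labels both sets equally well; a learned model such as
# FullyConnectedOracle is generally less reliable far from its training data
print(quadratic_oracle(x_data).ravel())
print(quadratic_oracle(x_new).ravel())
```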

Let me know if the example does not work as expected!

To run a design-baselines algorithm on the new task, specify the new task_id when calling the function that runs the baseline you are interested in, for example:

https://github.com/brandontrabucco/design-baselines/blob/master/design_baselines/cbas/__init__.py#L15

Warm regards, Brandon

bonaventuredossou commented 2 years ago

Specifically, I just want to reproduce COMs, introduced in your paper, on a custom dataset.

Does the snippet you provided above do that? If not, could you provide a direct way of doing this?

bonaventuredossou commented 2 years ago

Also, I have a "dumb" question: what does the "design_bench.oracles.tensorflow:FullyConnectedOracle" argument in the register function do?

bonaventuredossou commented 2 years ago

@brandontrabucco No, it did not work: it raised an error about the SmilesTokenizer.

brandontrabucco commented 2 years ago

It's possible that an incompatible version of rdkit or deepchem is installed.

Could you share the package versions in the Python environment you are working in? I have copied the versions of the packages I am using below; they are also listed in the requirements.txt file in the design-baselines repository.

absl-py==0.12.0
aiohttp==3.6.2
aiohttp-cors==0.7.0
aioredis==1.3.1
argon2-cffi==20.1.0
astunparse==1.6.3
async-timeout==3.0.1
attrs==19.3.0
backcall==0.2.0
beautifulsoup4==4.9.1
biopython==1.78
bleach==3.1.5
blessings==1.7
boto3==1.16.19
botocore==1.19.19
botorch==0.3.3
brotlipy==0.7.0
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
click==7.1.2
cloudpickle==1.3.0
cma==3.0.3
colorama==0.4.3
colorful==0.5.4
cycler==0.10.0
Cython==0.29.21
decorator==4.4.2
deepchem==2.5.0
defusedxml==0.6.0
dm-tree==0.1.5
docker==4.3.0
entrypoints==0.3
fasteners==0.16
filelock==3.0.12
flatbuffers==1.12
future==0.18.2
gast==0.3.3
glfw==1.12.0
google==3.0.0
google-api-core==1.22.2
google-auth==1.21.3
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
googleapis-common-protos==1.52.0
gpustat==0.6.0
gpytorch==1.3.0
grpcio==1.32.0
gym==0.12.5
h5py==2.10.0
hiredis==1.1.0
huggingface-hub==0.0.8
imageio==2.9.0
importlib-metadata==1.7.0
ipdb==0.13.4
ipykernel==5.3.4
ipython==7.17.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.2
Jinja2==2.11.2
jmespath==0.10.0
joblib==1.0.0
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.6
jupyter-console==6.1.0
jupyter-core==4.6.3
Keras-Preprocessing==1.1.2
kiwisolver==1.2.0
lmdb==1.0.0
lockfile==0.12.2
Markdown==3.2.2
MarkupSafe==1.1.1
matplotlib==3.3.0
mistune==0.8.4
msgpack==1.0.0
mujoco-py==2.0.2.3
multidict==4.7.6
nbconvert==5.6.1
nbformat==5.0.7
notebook==6.1.1
numpy==1.18.5
nvidia-ml-py3==7.352.0
oauthlib==3.1.0
opencensus==0.7.10
opencensus-context==0.1.1
opt-einsum==3.3.0
packaging==20.4
pandas==1.0.5
pandocfilters==1.4.2
parso==0.7.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.2.0
prometheus-client==0.8.0
prompt-toolkit==3.0.6
protobuf==3.17.1
psutil==5.7.2
ptyprocess==0.6.0
py-spy==0.3.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycosat==0.6.3
pyglet==1.5.0
Pygments==2.6.1
pyparsing==2.4.7
pyrsistent==0.16.0
python-dateutil==2.8.1
pytz==2020.1
PyYAML==5.3.1
pyzmq==19.0.2
qtconsole==4.7.5
QtPy==1.9.0
ray==1.3.0
redis==3.5.3
regex==2021.4.4
requests-oauthlib==1.3.0
robel==0.1.2
rsa==4.6
ruamel.yaml==0.15.87
s3transfer==0.3.3
sacremoses==0.0.45
scikit-learn==0.23.1
scikit-video==1.1.11
scipy==1.6.0
seaborn==0.11.0
Send2Trash==1.5.0
soupsieve==2.0.1
tabulate==0.8.7
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.7.0
tensorboardX==2.1
tensorflow==2.3.2
tensorflow-estimator==2.3.0
tensorflow-probability==0.11.0
termcolor==1.1.0
terminado==0.8.3
testpath==0.4.4
threadpoolctl==2.1.0
tokenizers==0.9.3
torch==1.7.1
torchvision==0.8.2
tornado==6.0.4
tqdm==4.56.0
traitlets==4.3.3
transformers==3.5.1
transforms3d==0.3.1
typing-extensions==3.7.4.3
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==0.57.0
Werkzeug==1.0.1
widgetsnbextension==3.5.1
wrapt==1.12.1
yarl==1.4.2
zipp==3.1.0
design-bench[all]==2.0.20
morphing-agents==1.5.1
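To compare your environment with the list above, `pip freeze` prints every installed package. The sketch below instead reports only a handful of packages that seem relevant to the SmilesTokenizer error (deepchem, rdkit, and transformers, which deepchem's SmilesTokenizer builds on), using the standard library's importlib.metadata (Python 3.8+). The package selection is an assumption; distribution names can differ from import names, and anything not found is reported as not installed.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

# packages that most likely matter for the SmilesTokenizer import error
candidates = ["design-bench", "deepchem", "rdkit", "transformers", "tensorflow", "numpy"]
for name, ver in installed_versions(candidates).items():
    print(f"{name}=={ver}" if ver else f"{name}: not installed")
```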

Regarding your question about what "design_bench.oracles.tensorflow:FullyConnectedOracle" does in the register function:

FullyConnectedOracle is a class that trains a neural network to approximate the f(x) -> y mapping of an MBO dataset. This can be necessary when there is no way to obtain ground-truth labels for new x values that are not present in the original dataset. If ground-truth labels can be obtained, however, an exact oracle is typically preferred over a learned model.
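The string itself names the oracle class in "module.path:ClassName" form, which appears to follow the same entry-point convention as gym.register. The sketch below shows how such a string can be resolved with importlib; `load_entry_point` is a hypothetical helper written for illustration, not design_bench's own code, and it resolves a standard-library class so it runs without TensorFlow installed.

```python
import importlib

def load_entry_point(spec):
    """Resolve a "module.path:ClassName" string to the object it names."""
    module_name, attr_name = spec.split(":")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)

# with design_bench installed, the string from the register call would
# resolve the same way:
#   load_entry_point("design_bench.oracles.tensorflow:FullyConnectedOracle")
# here a standard-library class stands in so the example runs anywhere
cls = load_entry_point("collections:OrderedDict")
print(cls)  # <class 'collections.OrderedDict'>
```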

As for your final question, you can run COMs on a new task from a Python script: first register a new Design-Bench task with your custom dataset, then call 'coms_cleaned.callback', passing in the new task ID and setting the hyperparameters to appropriate values.

Something like this:

import design_bench
from design_baselines.coms_cleaned import coms_cleaned

# register a new design_bench Task
design_bench.register( ... )

coms_cleaned.callback( ... )
brandontrabucco commented 2 years ago

The reason to call 'coms_cleaned.callback' rather than 'coms_cleaned' directly is that the original 'coms_cleaned' function is wrapped in a 'click' command-line interface, and 'coms_cleaned.callback' bypasses that wrapper.
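A minimal sketch of the difference, using a hypothetical `train` command rather than the real coms_cleaned signature: the object produced by `@click.command()` is a click Command, and its `callback` attribute holds the original, undecorated function, so calling it skips command-line parsing entirely.

```python
import click

@click.command()
@click.option("--task", type=str, default="Quadratic-FullyConnected-v0")
@click.option("--epochs", type=int, default=50)
def train(task, epochs):
    """Hypothetical click-wrapped training entry point."""
    return f"training on {task} for {epochs} epochs"

# calling train(...) would make click parse sys.argv and then exit the process;
# train.callback is the plain Python function, so it runs and returns normally
result = train.callback(task="Quadratic-FullyConnected-v0", epochs=5)
print(result)
```

One consequence of this design: the defaults declared in `@click.option` are applied by click's argument parsing, so calling `.callback` does not use them, and every parameter has to be passed explicitly.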