experimental-design / bofire

Experimental design and (multi-objective) Bayesian optimization.
https://experimental-design.github.io/bofire/
BSD 3-Clause "New" or "Revised" License

Classification surrogates #297

Closed: gmancino closed this 8 months ago

gmancino commented 1 year ago

Add Classification Models for Surrogates

Adding new surrogates to allow classification of output values (e.g. 'unacceptable', 'acceptable', or 'ideal') for modeling unknown constraints. Concretely, let $g_{\theta}:\mathbb{R}^d\to[0,1]^c$ be a function with learnable parameters $\theta$ that outputs a probability vector over $c$ potential classes (i.e. for input $x\in\mathbb{R}^d$, $g_{\theta}(x)^\top\mathbf{1}=1$, where $\mathbf{1}$ is the vector of all 1's), and let $a\in\{0,1\}^c$ give the acceptability criteria for the corresponding classes. We can then compute the expected acceptability as a scalar output via $g_{\theta}(x)^\top a\in[0,1]$, which can be passed in as a constraint.
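
To make the expected-acceptability formula concrete, here is a minimal NumPy sketch that evaluates $g_{\theta}(x)^\top a$ for a batch of predicted probability vectors. It is not BoFire code: the function name expected_acceptability and the example numbers are hypothetical, chosen only to illustrate the inner product.

```python
import numpy as np


def expected_acceptability(probs: np.ndarray, acceptability: np.ndarray) -> np.ndarray:
    """Return g_theta(x)^T a for every row of `probs` (hypothetical helper).

    probs: shape (n, c); each row is a probability vector over c classes.
    acceptability: shape (c,); the acceptability vector a.
    """
    probs = np.asarray(probs, dtype=float)
    acceptability = np.asarray(acceptability, dtype=float)
    # Each row must be a valid probability vector: g_theta(x)^T 1 = 1.
    if not np.allclose(probs.sum(axis=1), 1.0):
        raise ValueError("each row of probs must sum to 1")
    # Row-wise inner product with a gives a scalar in [0, 1] per input x.
    return probs @ acceptability


# Classes: 'unacceptable', 'acceptable', 'ideal'
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
a = np.array([0.0, 1.0, 1.0])  # 'acceptable' and 'ideal' both count as feasible
print(expected_acceptability(probs, a))  # approximately [0.3, 0.9]
```

Because $a$ only selects classes, the result is simply the total predicted probability of landing in an acceptable class.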

Classification Models

We add a new objective function in 'bofire/data_models/objectives/categorical.py' that can be passed to outputs of type CategoricalOutput. Such objectives are instantiated with a list of probability scales (i.e. the acceptability criteria vector) via the desirability argument and inherit their categories from the corresponding output. Using this new objective:

  1. To start, we implement an MLP ensemble ('bofire/data_models/surrogates/mlp.py'), which outputs a $c$-dimensional probability vector for each data point
  2. The predicted value (stored in {key}_{class}_prob of the predictions) is the argmax of this probability vector, while the objective value (stored in {key}_{class}_{des} of the predictions) is the inner product of the probability vector with the acceptability criteria vector
    • This value is also computed in the constrained_objective2botorch function, where it is currently passed through the inverse of the sigmoid transformation so that it stays in probability space
  3. We pass the objective value to BoTorch as a constraint

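
The steps above can be sketched in NumPy as follows. This is an illustrative sketch, not the code in this PR: it assumes the ensemble combines its members by averaging their softmax outputs, and the function names (softmax, ensemble_predict) and the eps clipping are hypothetical choices made for this example.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last (class) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def ensemble_predict(member_logits, categories, desirability, eps=1e-6):
    """Hypothetical post-processing of an MLP ensemble's raw outputs.

    member_logits: shape (m, n, c), outputs of m members for n points.
    Returns predicted labels, objective values, and their logits.
    """
    # Assumption: average the members' probability vectors -> shape (n, c).
    probs = softmax(np.asarray(member_logits, dtype=float)).mean(axis=0)
    # Predicted value: the class with the highest averaged probability.
    labels = [categories[i] for i in probs.argmax(axis=1)]
    # Objective value: inner product with the desirability vector.
    objective = probs @ np.asarray(desirability, dtype=float)
    # Inverse sigmoid (logit); clipping keeps values of exactly 0 or 1 finite.
    clipped = np.clip(objective, eps, 1 - eps)
    logit = np.log(clipped / (1 - clipped))
    return labels, objective, logit


categories = ("unacceptable", "acceptable", "ideal")
rng = np.random.default_rng(0)
member_logits = rng.normal(size=(5, 4, 3))  # 5 members, 4 points, 3 classes
labels, objective, logit = ensemble_predict(member_logits, categories, [0.0, 1.0, 1.0])

# A sigmoid applied to the logit recovers the (clipped) probability.
print(np.allclose(1 / (1 + np.exp(-logit)), np.clip(objective, 1e-6, 1 - 1e-6)))  # True
```

The final check illustrates why applying the inverse sigmoid keeps the value in probability space: any later sigmoid maps it straight back to the expected acceptability.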
gmancino commented 1 year ago

@jduerholt, please let me know if the style of these updates is appropriate. I will build more models if the initial idea makes sense.

gmancino commented 1 year ago

@jduerholt thank you so much for your feedback! I think I have addressed all of your concerns, but please let me know if there are any remaining errors. If it all looks good, I plan on implementing something along the lines of https://github.com/pytorch/botorch/issues/640 for a categorical GP. This should cover our initial bases :)

gmancino commented 10 months ago

Hi @jduerholt, I have made the changes to this PR that we discussed. Some tests are currently failing because I changed TCategoryVals to type Tuple[str, ...], which makes the categories immutable. As a result, the tests fail wherever categories are hard-coded (currently as lists). I know we discussed using tuples instead of lists for this, so it may be worth changing these by hand, unless you have another opinion? There may also be technical comments on the changes that need to be addressed first ;)
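
One plausible way hard-coded list categories break once TCategoryVals becomes Tuple[str, ...] is plain equality: in Python a list never compares equal to a tuple, even with identical items. The snippet below uses a simplified stand-in for TCategoryVals (an assumption; BoFire's actual alias may carry extra validation) to show both the immutability and the comparison.

```python
from typing import List, Tuple

# Simplified stand-in for the alias discussed above (assumption).
TCategoryVals = Tuple[str, ...]

categories: TCategoryVals = ("unacceptable", "acceptable", "ideal")

# Tuples are immutable: item assignment raises TypeError at runtime.
try:
    categories[0] = "bad"  # type: ignore[index]
except TypeError as err:
    print("immutable:", err)

# A list and a tuple with the same items do not compare equal,
# so assertions against hard-coded lists fail.
hard_coded: List[str] = ["unacceptable", "acceptable", "ideal"]
print(hard_coded == categories)         # False
print(tuple(hard_coded) == categories)  # True
```

Converting the hard-coded values to tuples (or wrapping them in tuple(...)) is enough to make such comparisons pass again.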

jduerholt commented 9 months ago

Regarding your failing test: something went wrong in one of your merges against main. In main, the line with the correct terms looks as follows: https://github.com/experimental-design/bofire/blob/bf2a11999a7ba779a758c70ee17a555213676d6b/tests/bofire/strategies/doe/test_utils.py#L77

In your branch it looks like this: https://github.com/experimental-design/bofire/blob/20ec136ffdf8bf662c9f332aac2f25bdfbf85116/tests/bofire/strategies/doe/test_utils.py#L77. At some point we updated main to "x ** 2", which is the new formulaic format, whereas your branch still has "x**2", which is the old format.

I assume a merge went wrong somewhere.

gmancino commented 9 months ago

@jduerholt This is ready for another round of reviews :) Hopefully the last one!

gmancino commented 8 months ago

@jduerholt, I have made all of the previously requested changes (including tests). Please let me know what you think once you have some time :)

gmancino commented 8 months ago

@jduerholt the requested changes are complete :)