pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Feature Request] Trainable parametric outcome transforms #1174

Open sdaulton opened 2 years ago

sdaulton commented 2 years ago

🚀 Feature Request

Currently, input transforms can be parametric, and their parameters can be optimized jointly with the other hyperparameters of the GP (input warping does this, for example). Outcome transforms, on the other hand, are applied only once, upon initialization. As a result, one cannot currently implement a parametric outcome transform and optimize its parameters jointly with the GP hyperparameters.
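The difference in when the two kinds of transform run can be shown with a deliberately simplified, pure-Python toy. ShiftTransform and ToyModel are hypothetical stand-ins, not BoTorch classes. The sketch only assumes, as described above, that the input transform is re-applied each time the model is evaluated, while the outcome transform is applied once at construction.

```python
from typing import Callable, List

Transform = Callable[[List[float]], List[float]]


class ShiftTransform:
    """Hypothetical parametric transform v -> v + shift (not a BoTorch class)."""

    def __init__(self, shift: float) -> None:
        self.shift = shift  # the "trainable" parameter

    def __call__(self, values: List[float]) -> List[float]:
        return [v + self.shift for v in values]


class ToyModel:
    """Simplified stand-in for a GP model; mirrors only when transforms run."""

    def __init__(self, train_X: List[float], train_Y: List[float],
                 input_transform: Transform, outcome_transform: Transform) -> None:
        self.train_X = train_X
        self.input_transform = input_transform
        # Outcome transform: applied once here, and only the result is kept.
        self.train_Y = outcome_transform(train_Y)

    def transformed_train_inputs(self) -> List[float]:
        # Input transform: re-applied on every call, so the current parameter
        # value is always used (in a real model, this is what lets it get gradients).
        return self.input_transform(self.train_X)


in_tf, out_tf = ShiftTransform(0.0), ShiftTransform(0.0)
model = ToyModel([1.0, 2.0], [10.0, 20.0], in_tf, out_tf)
in_tf.shift = 1.0   # e.g. an optimizer step on the input-transform parameter
out_tf.shift = 1.0  # an optimizer step on the outcome-transform parameter
print(model.transformed_train_inputs())  # [2.0, 3.0]: the update is reflected
print(model.train_Y)                     # [10.0, 20.0]: the update has no effect
```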

Motivation

Support new parametric outcome transforms whose parameters are fit jointly with the GP hyperparameters.

Pitch

Support inferring the parameters of parametric outcome transforms, similar to what is already possible for input transforms.

Describe alternatives you've considered

None

Are you willing to open a pull request? (See CONTRIBUTING)

When time allows...

cc @Balandat @dme65 @saitcakmak

sdaulton commented 2 years ago

cc @bajgar

saitcakmak commented 2 years ago

Interesting idea. Would we then need to move the application of the outcome transform into the forward call? Since forward doesn't really deal with the train outcomes, it might instead need to go wherever we actually compute the training loss.
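A minimal pure-Python sketch of a training loss that applies a trainable outcome transform g inside the loss computation, on every evaluation, rather than once at initialization. It assumes a zero-mean GP with a 1-d RBF kernel. All function names are hypothetical and are not BoTorch or GPyTorch APIs. Because the GP is fit to g(y), the change-of-variables term sum_i log|g'(y_i)| is subtracted so that the objective stays the likelihood of the original outcomes.

```python
import math
from typing import Callable, List, Tuple

# A parametric outcome transform as a pair of functions: (g, log|g'|).
OutcomeTransform = Tuple[Callable[[float], float], Callable[[float], float]]


def rbf_kernel_matrix(xs: List[float], lengthscale: float,
                      outputscale: float, noise: float) -> List[List[float]]:
    """K + noise * I for a 1-d RBF kernel."""
    n = len(xs)
    return [[outputscale * math.exp(-0.5 * ((xs[i] - xs[j]) / lengthscale) ** 2)
             + (noise if i == j else 0.0) for j in range(n)] for i in range(n)]


def cholesky(a: List[List[float]]) -> List[List[float]]:
    """Lower-triangular L with L L^T == a (a must be positive definite)."""
    n = len(a)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(lower[i][k] * lower[j][k] for k in range(j))
            lower[i][j] = math.sqrt(s) if i == j else s / lower[j][j]
    return lower


def gp_nll(xs, ys, lengthscale, outputscale, noise) -> float:
    """Negative log marginal likelihood of ys under a zero-mean GP."""
    lower = cholesky(rbf_kernel_matrix(xs, lengthscale, outputscale, noise))
    z: List[float] = []  # forward substitution: solve L z = y
    for i, y in enumerate(ys):
        z.append((y - sum(lower[i][k] * z[k] for k in range(i))) / lower[i][i])
    quad = sum(v * v for v in z)  # y^T K^{-1} y
    logdet = 2.0 * sum(math.log(lower[i][i]) for i in range(len(ys)))
    return 0.5 * (quad + logdet + len(ys) * math.log(2.0 * math.pi))


def training_loss(xs, ys, transform: OutcomeTransform,
                  lengthscale, outputscale, noise) -> float:
    """NLL of the original ys when the GP is fit to g(ys)."""
    g, log_abs_dg = transform
    nll_transformed = gp_nll(xs, [g(y) for y in ys], lengthscale, outputscale, noise)
    return nll_transformed - sum(log_abs_dg(y) for y in ys)  # Jacobian correction


def scale_transform(scale: float) -> OutcomeTransform:
    """Hypothetical parametric outcome transform g(y) = scale * y."""
    return (lambda y: scale * y), (lambda y: math.log(abs(scale)))


xs, ys = [0.0, 0.5, 1.0, 2.0], [0.3, -0.1, 0.8, 1.5]
base = gp_nll(xs, ys, 1.0, 1.0, 0.1)
# Scaling outcomes by 3 and kernel/noise by 3**2 leaves the corrected loss unchanged.
scaled = training_loss(xs, ys, scale_transform(3.0), 1.0, 9.0, 0.9)
# Without the Jacobian term, shrinking the outcomes (and kernel) lowers the loss.
shrunk = gp_nll(xs, [0.1 * y for y in ys], 1.0, 0.01, 0.001)
```

The last line shows why the correction matters. If the Jacobian term is dropped, shrinking g together with the kernel scale reduces the loss by n log(1/scale). That makes the joint objective unbounded below as scale goes to 0.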

I just exported #1176, which proposes a refactor of how input transforms are applied. We had some discussions about it internally, and it'd be nice to get more feedback. TL;DR: it proposes turning the current forward methods into a private _forward (and posterior into a private _posterior), defining the public forward in Model, and applying the input transforms there. This eliminates the need to deal with transforms in every model. To make this work with one-to-many transforms, the idea is to give those their own class, separate from the other input transforms, and to apply them only in the posterior call.
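A rough sketch of that structure, in pure Python with plain lists standing in for tensors. SketchModel and DoublingModel are hypothetical names, not the code in #1176. The order in which the one-to-many transform and the regular input transform compose in posterior is an assumption of this sketch.

```python
from typing import Callable, List, Optional

Transform = Callable[[List[float]], List[float]]


class SketchModel:
    """Hypothetical base class mirroring the split proposed in #1176."""

    def __init__(self, input_transform: Optional[Transform] = None,
                 one_to_many_transform: Optional[Transform] = None) -> None:
        self.input_transform = input_transform
        self.one_to_many_transform = one_to_many_transform

    def forward(self, X: List[float]) -> List[float]:
        # Public forward: input transforms are applied once, here, so that
        # subclasses never have to handle them.
        if self.input_transform is not None:
            X = self.input_transform(X)
        return self._forward(X)

    def posterior(self, X: List[float]) -> List[float]:
        # One-to-many transforms live in their own slot and are applied
        # only in the posterior call.
        if self.one_to_many_transform is not None:
            X = self.one_to_many_transform(X)
        return self._posterior(X)

    def _forward(self, X: List[float]) -> List[float]:
        raise NotImplementedError

    def _posterior(self, X: List[float]) -> List[float]:
        raise NotImplementedError


class DoublingModel(SketchModel):
    """Toy subclass: implements only the private hooks, with no transform logic."""

    def _forward(self, X: List[float]) -> List[float]:
        return [2.0 * x for x in X]

    def _posterior(self, X: List[float]) -> List[float]:
        # A real model would return a posterior distribution; this toy simply
        # routes through the public forward, which applies the input transform.
        return self.forward(X)


model = DoublingModel(
    input_transform=lambda X: [x + 1.0 for x in X],
    one_to_many_transform=lambda X: [v for x in X for v in (x, -x)],
)
print(model.forward([1.0, 2.0]))  # [4.0, 6.0]
print(model.posterior([1.0]))     # [4.0, 0.0]
```

The design choice being illustrated is the template-method pattern. The base class owns the public entry points and the transform handling, and concrete models supply only _forward and _posterior.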

Balandat commented 2 years ago

Yeah, I think it would be great to have that. I do agree with @saitcakmak that this is not a straightforward extension of what we do for the input transforms.