pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Feature Request] Minibatch training when mll is an _ApproximateMarginalLogLikelihood #1438

Open jacobrgardner opened 2 years ago

jacobrgardner commented 2 years ago

🚀 Feature Request

Now that fit_gpytorch_mll exists and uses multiple dispatch, it seems like it would be fairly straightforward to support minibatch training by registering a fit_gpytorch_torch_stochastic (or similar) as the optimizer for _ApproximateMarginalLogLikelihood MLLs.
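
For context, here is a minimal sketch of the kind of minibatch ELBO training loop such a registered optimizer would run under the hood, assuming a standard GPyTorch variational setup; model, likelihood, train_x, train_y, and num_epochs are placeholders:

import torch
from gpytorch.mlls import VariationalELBO
from torch.utils.data import DataLoader, TensorDataset

# Placeholders: `model` is a variational/approximate GP, `likelihood` its
# likelihood, and (train_x, train_y) the training data.
mll = VariationalELBO(likelihood, model, num_data=train_y.size(0))
loader = DataLoader(TensorDataset(train_x, train_y), batch_size=256, shuffle=True)
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(likelihood.parameters()), lr=0.01
)

model.train()
likelihood.train()
for _ in range(num_epochs):  # num_epochs is a placeholder as well
    for x_batch, y_batch in loader:
        optimizer.zero_grad()
        loss = -mll(model(x_batch), y_batch)  # negative ELBO on the minibatch
        loss.backward()
        optimizer.step()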

Motivation

Is your feature request related to a problem? Please describe.

As far as I can tell from browsing the code, running fit_gpytorch_mll on an ApproximateGPyTorchModel just uses full-batch training. As a result, we have typically still been writing our own GPyTorch models and training code (e.g., for latent-space optimization tasks), despite the existence of ApproximateGPyTorchModel. We're planning to submit a PR with a latent-space BayesOpt tutorial, but I'd like it to be more BoTorch-y than it currently is; right now, the actual model handling sits entirely outside of BoTorch for this reason.

Pitch

Describe the solution you'd like

Are you willing to open a pull request? (See CONTRIBUTING)

Yeah

Balandat commented 2 years ago

This is a great suggestion, and IIRC @j-wilson has played around with this a bit before. Not sure what state that is in and whether it makes sense for him to push out a draft of this, or whether it's better to just start fresh with a PR on your end (seems reasonably straightforward all in all). @j-wilson any thoughts here?

One solution might be to just call the fallback fit if a minibatch size / optimizer isn't specified by the user? On the other hand, in the long run, it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically rather than just e.g. an inducing point kernel on an ExactGP.

Yeah that makes sense to me. I think if the user is using variational approximate GP models we can assume that they'd be able to manually specify the full batch training optimizer if needed. Another option would be to parse this somehow from the kwargs, but I don't think we need to worry about this for now.

jacobrgardner commented 2 years ago

Here's a rough draft of what this might look like: https://github.com/pytorch/botorch/compare/main...jacobrgardner:botorch:stochastic_fitting

The high-level fitting works well (both on a piece of code I've been testing and on our homebrew model fitting). There are still a few TODOs to address even before code review.

@j-wilson @Balandat just let me know if you all don't already have something further along than this, and I can open this as a PR to track the TODOs there.

(Edit: Oops, some automated thing must have run black on the files before committing; sorry about the irrelevant parts of the linked diff.)

saitcakmak commented 2 years ago

IIRC, @mshvartsman et al. use ApproximateGP with full-batch training. cc'ing in case they have any input on this.

j-wilson commented 2 years ago

@jacobrgardner Hi Jake. Fully on board with you here. As Max mentioned, I've put together a draft for this as well. At a glance, it looks pretty similar to your implementation.

The main difference seems to be that I actually just rewrote fit_gpytorch_torch to be more generic instead of introducing a separate method. This isn't necessarily a better way of doing things; I just don't like the current fit_gpytorch_torch method...

Aside from that, I have data_loader as an optional argument, with the method defaulting to full-batch training. Under this approach, the responsibility of constructing data_loader is off-loaded to the fit_gpytorch_mll subroutine. This same subroutine would also be responsible for throwing an MDNotImplementedError in cases where the amount of training data is small enough that scipy.optimize is the preferred optimizer.

Would something like this work for your use cases?

Regarding GPyTorchDataset, I'm not sure I understand the need for this class. How about:

from torch.utils.data import DataLoader, TensorDataset

# Wrap the model's stored training data in a standard TensorDataset.
dataset = TensorDataset(*model.train_inputs, model.train_targets)
data_loader = DataLoader(dataset, **kwargs)
for batch_idx, (*inputs, targets) in enumerate(data_loader):
    # do stuff with each minibatch

If we end up with cases where train_targets is also a Tuple[Tensor, ...], we'd need to update the (*inputs, targets) bit, but that seems doable?

jacobrgardner commented 2 years ago

@j-wilson Ah, yeah looks like we can just use TensorDataset there.

In terms of the rest, how do you envision the user specifying that minibatch training should be used? Would the idea be to do something like fit_gpytorch_mll(mll, my_data_loader), overriding the use of model.train_inputs entirely? Or would I specify the minibatch size, and fit_gpytorch_mll would do some type checking to make sure I'm using Approximate*, and then construct a data loader?

I guess I'm personally fine with essentially any of the proposed interfaces here.

j-wilson commented 2 years ago

@jacobrgardner Good questions.

I hadn't actually considered a solution like fit_gpytorch_mll(mll, data_loader). I really like this API, but fear it may be too heavy-duty for simple use cases.

My thought had been to make a create_data_loader(model, **kwargs) -> DataLoader helper that abstracts away DataLoader construction. We would then add data_loader: Union[DataLoader, Dict[str, Any]] as a keyword to the MD subroutine, which would internally call create_data_loader(mll.model, **data_loader) when data_loader is passed as a dict.

A typical call pattern might then look something like:

fit_gpytorch_mll(mll, data_loader={"batch_size": 128})
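
To make that concrete, here is a minimal sketch of what such a create_data_loader helper might look like (hypothetical; it follows the proposal above rather than any existing BoTorch function):

from typing import Any

from torch.utils.data import DataLoader, TensorDataset


def create_data_loader(model, **kwargs: Any) -> DataLoader:
    # Hypothetical helper: build a DataLoader over the model's stored training
    # data, forwarding user options such as batch_size and shuffle.
    dataset = TensorDataset(*model.train_inputs, model.train_targets)
    return DataLoader(dataset, **kwargs)

The MD subroutine would then call it with the user-supplied dict, e.g. create_data_loader(mll.model, batch_size=128), as described above.
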
jacobrgardner commented 2 years ago

@j-wilson Okay, so if you all think a rewrite of fit_gpytorch_torch is warranted, maybe the right solution here is something in the middle, where we add a top-level _fit_approximategp_stochastic; then we can use the dispatcher to typecheck that the mll is an ApproximateMLL and the model is an approximate model.

Then both _fit_approximategp_stochastic and _fit_fallback would end up calling fit_gpytorch_torch (or scipy, for the latter), but _fit_approximategp_stochastic would enable batch size < N functionality, while _fit_fallback would throw a warning if the user specifies batch_size < N, with the warning saying that the types didn't match well enough for the dispatcher, so we're doing full-batch training?
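
A hypothetical sketch of how that split might look; the dispatcher registration is only described in comments, and all names, signatures, and import paths here are assumptions rather than the actual implementation:

import warnings

from botorch.utils.dispatcher import MDNotImplementedError  # import path assumed


def _fit_approximategp_stochastic(mll, _, __, *, batch_size=None, **kwargs):
    # Hypothetically registered with the fit_gpytorch_mll dispatcher for the
    # type triple (_ApproximateMarginalLogLikelihood, Likelihood,
    # ApproximateGPyTorchModel), so it only fires for approximate models.
    num_data = len(mll.model.train_targets)
    if batch_size is None or batch_size >= num_data:
        # Nothing stochastic to do; fall through to the next candidate
        # (e.g. _fit_fallback) via multiple dispatch.
        raise MDNotImplementedError
    ...  # run the minibatch training loop (e.g. via fit_gpytorch_torch)


def _fit_fallback(mll, _, __, *, batch_size=None, **kwargs):
    num_data = len(mll.model.train_targets)
    if batch_size is not None and batch_size < num_data:
        warnings.warn(
            "batch_size < N was specified, but the mll/model types did not "
            "match the stochastic fitting route; performing full-batch fitting."
        )
    ...  # existing full-batch fitting path (scipy / torch)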

j-wilson commented 2 years ago

@jacobrgardner Up for discussion. A naive implementation would probably see data_loader as a _fit_approximategp_stochastic-specific keyword argument that gets ignored by other fit_gpytorch_mll subroutines.

mshvartsman commented 2 years ago

> it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically

Not sure if that's a safe assumption :). As @saitcakmak said, in AEPsych we pretty much exclusively do full-batch, small-data fit_gpytorch_model with ApproximateGP (mainly here: https://github.com/facebookresearch/aepsych/blob/main/aepsych/models/base.py#L423). Most of what we use VI for is non-Gaussian likelihoods (Bernoulli, categorical, etc.), not big data. It's not a huge deal for us to change our calls, but in my experience the scipy optimizer is dramatically faster when it works, and I wouldn't want the default for new users to be some flavor of SGD on small data. I don't think that, when I started using gpytorch/botorch, I would have known to switch optimizers for our setting.

So my vote would be to either [a] retain the full-batch-with-SAA default and warn if the data is too large, or [b] have a sensible, user-adjustable cutoff to switch between the fitting strategies (similar to how gpytorch switches between Cholesky and CG for solves and logdets). I think I'd prefer [b] over [a]; we'd just need to tune the cutoff.
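
For concreteness, option [b] might look something like the sketch below; the cutoff value and the way it is exposed to users are purely illustrative:

# Illustrative only: a user-adjustable cutoff, analogous in spirit to how
# gpytorch switches between Cholesky and CG based on problem size.
MAX_FULL_BATCH_DATA = 2048  # hypothetical default; would need tuning


def choose_fitting_strategy(mll, cutoff: int = MAX_FULL_BATCH_DATA) -> str:
    num_data = len(mll.model.train_targets)
    # Small data: keep the full-batch SAA/scipy default.
    # Large data: switch to minibatch/stochastic fitting with a torch optimizer.
    return "full_batch_scipy" if num_data <= cutoff else "minibatch_torch"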

j-wilson commented 2 years ago

Hi folks. I've put together a PR (#1439) that implements the above. This ended up being a larger change than I had originally anticipated, but hopefully people will agree that these are the "right" changes (or at least trending in that direction).

In terms of balancing the specific functionality requested here with the overall design, the best course of action seemed to be to introduce a loss closure abstraction. This allows us to abstract away things like DataLoaders, while also enabling users to specify custom routines for evaluating their loss functions.
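
To sketch the idea in hedged form (not the PR's actual interface): a loss closure is just a zero-argument callable that returns the next loss value, so a minibatch variant can hide the DataLoader inside the closure:

from typing import Callable

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_minibatch_loss_closure(mll, batch_size: int = 128) -> Callable[[], torch.Tensor]:
    # Hypothetical sketch: each call draws the next minibatch and returns the
    # corresponding loss (negative ELBO), reshuffling once a pass completes.
    dataset = TensorDataset(*mll.model.train_inputs, mll.model.train_targets)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    batches = iter(loader)

    def closure() -> torch.Tensor:
        nonlocal batches
        try:
            *inputs, targets = next(batches)
        except StopIteration:
            batches = iter(loader)  # start a new (reshuffled) pass over the data
            *inputs, targets = next(batches)
        return -mll(mll.model(*inputs), targets)

    return closure

A generic fitting loop can then repeatedly call the closure without knowing whether it is full-batch or minibatch underneath.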

I haven't tested this yet, but I'm hopeful that we'll be able to use e.g. torch.jit.script to compile these closures and expedite training.

Balandat commented 1 year ago

> I'm hopeful that we'll be able to use e.g. torch.jit.script to compile these closures and expedite training.

Sam has been having some good success using torchdynamo/torchinductor; it would be interesting to see what that does here.