Open mhrmsn opened 2 months ago
Last point is now addressed in this PR: pytorch/botorch#2502
Hi @mhrmsn, thanks for investigating and raising the botorch issue 🥇
Solid GPU support is definitely something we should target and, in particular, also test, which brings up yet another point to think about. That involves figuring out both what reasonable GPU tests look like and how/where we can run them (AFAIK GitHub now offers GPU runners, but I have no clue if we can get easy access).
Regarding the other points you mention: `torch.set_default_device("cuda")` is something users can then easily do once the above points are addressed. And agreed: I've been thinking for quite a while about adding a "settings" mechanism (also to control other things like floating point precision, etc.). I have a few ideas in mind but haven't really had the time to implement them yet. Perhaps we can talk about it once I'm back from vacation?
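Just to make the "settings" idea a bit more tangible, here is a minimal sketch of what such a mechanism could look like. Everything in it is an assumption for illustration only: the `Options` class, its `device` and `float_precision` fields, and the `override` helper are hypothetical names, not an existing or planned BayBE API.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class Options:
    """Hypothetical package-wide settings object (not an existing BayBE API)."""

    device: str = "cpu"
    float_precision: int = 64  # hypothetical: bit width of floating point values

    @contextmanager
    def override(self, **changes):
        """Temporarily change settings and restore the old values on exit."""
        # getattr raises AttributeError for unknown names, which guards against typos.
        previous = {name: getattr(self, name) for name in changes}
        for name, value in changes.items():
            setattr(self, name, value)
        try:
            yield self
        finally:
            for name, value in previous.items():
                setattr(self, name, value)


options = Options()
options.device = "cuda"  # global choice, analogous to `baybe.options.device = "cuda"`

# Run one part of a workflow on a different device without touching the global choice.
with options.override(device="cpu"):
    inner_device = options.device
```

A context manager like this would cover the case where different parts of a workflow need different devices, while plain attribute assignment covers the simple global case.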
As I recently looked into this, and after discussion with @AdrianSosic and @Scienfitz, here are my observations:

I think there are two convenient ways to support GPUs: either by allowing the user to use `torch.set_default_device("cuda")`, or by adding a configuration variable within the package, something like `baybe.options.device = "cuda"`. I only tested the first one. I think the latter is a bit more complex, but it would potentially allow different parts of the code in the overall BayBE workflow to use different devices (this may be useful if you are generating embeddings from a torch-based neural network for use in BayBE, etc.).

When experimenting with `torch.set_default_device("cuda")`, I noticed that the devices for the tensors are not consistently set in BayBE. For either solution, I think these points would need to be addressed:

- `torch.from_numpy()` calls that are used to construct botorch inputs share the same memory as the input, so the resulting tensors stay on the CPU regardless of the default device. The same goes for `torch.frombuffer()` (not used in BayBE). See also: https://pytorch.org/docs/stable/generated/torch.set_default_device.html
- `pd.DataFrame(points, ...)` where `points` is a tensor from botorch. This will fail if the tensor is not on the CPU.
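The two points above can be illustrated with a minimal sketch. The helper names `to_device_tensor` and `to_dataframe` are hypothetical and only show one possible way to handle the device explicitly; they are not how BayBE currently does it. The device falls back to the CPU so the snippet also runs without a GPU.

```python
import numpy as np
import pandas as pd
import torch

# Fall back to the CPU so the sketch also runs on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device_tensor(array: np.ndarray, device: torch.device = DEVICE) -> torch.Tensor:
    """Hypothetical helper: turn a NumPy array into a tensor on `device`."""
    # torch.from_numpy() always yields a CPU tensor that shares memory with
    # `array`, so the move to the target device has to be explicit.
    return torch.from_numpy(array).to(device)


def to_dataframe(points: torch.Tensor, columns: list[str]) -> pd.DataFrame:
    """Hypothetical helper: build a DataFrame from a tensor on any device."""
    # Detach and copy to host memory first, since NumPy and pandas cannot
    # read GPU memory directly.
    return pd.DataFrame(points.detach().cpu().numpy(), columns=columns)


train_x = to_device_tensor(np.array([[0.1, 0.2], [0.3, 0.4]]))
recommendations = to_dataframe(train_x * 2, columns=["x1", "x2"])
```

On a CPU-only machine both helpers are close to no-ops, so the same code path could be used regardless of the selected device.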