pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] Implicit assumption of double precision can cause failures when single precision is used #2596

Open AVHopp opened 1 month ago

AVHopp commented 1 month ago

🐛 Bug

Within test_functions/base.py, the bounds are hard-coded to double precision. This makes it impossible to use single precision, which is necessary for, e.g., MPS support on Mac.

To reproduce

import torch
from botorch.test_functions import Rastrigin

torch.set_default_device("mps")
torch.set_default_dtype(torch.float32)

test = Rastrigin()

This yields the following error message (paths shortened for better readability):

Traceback (most recent call last):
  File "[...]/test.py", line 7, in <module>
    test = Rastrigin()
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/synthetic.py", line 664, in __init__
    super().__init__(noise_std=noise_std, negate=negate, bounds=bounds)
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/synthetic.py", line 83, in __init__
    super().__init__(noise_std=noise_std, negate=negate)
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/base.py", line 51, in __init__
    "bounds", torch.tensor(self._bounds, dtype=torch.double).transpose(-1, -2)
  File "[...]/lib/python3.10/site-packages/torch/utils/_device.py", line 79, in __torch_function__
    return func(*args, **kwargs)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Expected Behavior

The initialization of the BaseTestProblem class should not enforce double precision in the buffer for the bounds but probably use torch.get_default_dtype()

Additional context

These seem to be the problematic lines: https://github.com/pytorch/botorch/blob/9d37e905639e0d4983e52ce425306b8161760ee4/botorch/test_functions/base.py#L49-L51

If the error is in fact just this one line of code, I'd be more than happy to create a mini Pull Request if it makes sense :)

AVHopp commented 1 month ago

NOTE: I have verified locally that changing the line of code that I mentioned under "Additional context" solves the problem.

Balandat commented 1 month ago

The initialization of the BaseTestProblem class should not enforce double precision in the buffer for the bounds but probably use torch.get_default_dtype()

Yep, that makes sense. If self._bounds are python floats then this will automatically happen, but I guess some of these might be defined as int, so using torch.get_default_dtype() explicitly here seems good so we don't end up with int tensors.
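A minimal sketch of the dtype inference described above, using made-up bounds rather than those of any actual BoTorch test function: torch.tensor infers an integer dtype from Python ints, whereas passing torch.get_default_dtype() explicitly yields a floating-point tensor in the current default precision.

```python
import torch

# Hypothetical bounds, written once with ints and once with floats.
int_bounds = [(-5, 5), (-5, 5)]
float_bounds = [(-5.0, 5.0), (-5.0, 5.0)]

# Without an explicit dtype, Python ints produce an integer tensor ...
inferred_from_ints = torch.tensor(int_bounds)
# ... while Python floats pick up the current default dtype.
inferred_from_floats = torch.tensor(float_bounds)

# Passing the default dtype explicitly gives a float tensor even for int input.
explicit = torch.tensor(int_bounds, dtype=torch.get_default_dtype())

print(inferred_from_ints.dtype, inferred_from_floats.dtype, explicit.dtype)
```

The explicit variant is the one that avoids both the hard-coded torch.double and accidental integer bounds.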

If the error is in fact just this one line of code, I'd be more than happy to create a mini Pull Request if it makes sense :)

That would be very welcome!

AVHopp commented 1 month ago

Great, then I'll draft the PR :)

Balandat commented 1 month ago

Great. Btw, have you been using BoTorch more generally with MPS on Mac? I have tried this in the past but a lot of operations that we (and gpytorch) use under the hood weren't supported by MPS at the time, curious if that has changed more recently.

AVHopp commented 1 month ago

I am currently trying to bring GPU support to a package that I co-develop (happy to share the link if shameless self-advertisement is fine 😆 ) that uses BoTorch in the backend. This issue was the first one that I encountered, and I haven't investigated in more detail yet. I can, however, share whatever I find, here or elsewhere, if I run into more issues and you are interested 😄

AVHopp commented 1 month ago

@Balandat I found another occurrence of the same issue:

import numpy as np
import torch
from botorch.optim.utils.numpy_utils import set_tensors_from_ndarray_1d
from torch import tensor

torch.set_default_device("mps")
tensors = (tensor([-2.2532]),)

array = np.asarray([-2.25321913], dtype=np.float64)

set_tensors_from_ndarray_1d(tensors, array)

This fails since the as_tensor call in the function does not take the default dtype into account: https://github.com/pytorch/botorch/blob/9d37e905639e0d4983e52ce425306b8161760ee4/botorch/optim/utils/numpy_utils.py#L113-L128 The error happens in the last line. Replacing this line with one that uses torch.as_tensor(...), or changing the default value to a whacky lambda expression, as_tensor: Callable[[ndarray], Tensor] = lambda x: torch.as_tensor(x, dtype=torch.get_default_dtype()), seems to solve the issue (at least on the BoTorch version I pointed out here).

It seems like there are several places that implicitly assume double precision, or at least enforce a dtype other than the default one. Should I just create small PRs whenever I find these or is there a better way to go forward?

Balandat commented 1 month ago

It seems like there are several places that implicitly assume double precision, or at least enforce a dtype other than the default one.

The main reason this implicit assumption is present in various places in the codebase is that we've found that when working with GPs we often end up with rather ill-conditioned matrices, and performing things like matrix decompositions or linear solves on those with FP32 precision can be quite hairy and result in poor accuracy and often also outright failures.
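To illustrate the precision concern, here is a small sketch that uses a Hilbert matrix as a stand-in for an ill-conditioned kernel matrix (an assumption made for illustration; it is not a GP covariance and not BoTorch code). It solves the same linear system in FP32 and FP64 and compares the error against the known solution.

```python
import torch

n = 6
# Hilbert matrix H[i, j] = 1 / (i + j + 1): a classic ill-conditioned matrix.
idx = torch.arange(n, dtype=torch.float64)
H = 1.0 / (idx[:, None] + idx[None, :] + 1.0)

# Build the right-hand side from a known solution so the error is measurable.
x_true = torch.ones(n, dtype=torch.float64)
b = H @ x_true


def max_solve_error(dtype: torch.dtype) -> float:
    """Solve H x = b in the given precision and return the max abs error."""
    x = torch.linalg.solve(H.to(dtype), b.to(dtype))
    return (x.to(torch.float64) - x_true).abs().max().item()


err32 = max_solve_error(torch.float32)
err64 = max_solve_error(torch.float64)
print(f"float32 error: {err32:.3e}, float64 error: {err64:.3e}")
```

The exact numbers depend on the backend, but the single-precision error should come out several orders of magnitude larger than the double-precision one.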

Should I just create small PRs whenever I find these or is there a better way to go forward?

That would be great and I think it makes sense to fix those as they come up.

Balandat commented 1 month ago

Actually, regarding the issue about set_tensors_from_ndarray_1d: we did change this in the past, IIRC in response to some flaky tests: https://github.com/pytorch/botorch/pull/1508/files - That doesn't mean the current setup is the right one though. cc @esantorella

The error happens in the last line. Replacing this line with one that uses torch.as_tensor(...), or changing the default value to a whacky lambda expression, as_tensor: Callable[[ndarray], Tensor] = lambda x: torch.as_tensor(x, dtype=torch.get_default_dtype()), seems to solve the issue (at least on the BoTorch version I pointed out here).

Looking at the code I don't actually ever see the as_tensor arg of NdarrayOptimizationClosure be used; seems this is always the default. Given that I would probably recommend just getting rid of that additional complexity and simply use torch.as_tensor() in the body of the function, and use the device and dtype of the input tensors as the arguments to it.
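A rough sketch of what that recommendation could look like. The function name set_tensors_from_flat_array is hypothetical and this is not the actual BoTorch implementation; it only shows torch.as_tensor being called with the device and dtype of each target tensor.

```python
from typing import Sequence

import numpy as np
import torch
from torch import Tensor


def set_tensors_from_flat_array(tensors: Sequence[Tensor], array: np.ndarray) -> None:
    """Copy consecutive slices of a 1-D array into the given tensors in place.

    Hypothetical helper: each slice is converted with the target tensor's own
    device and dtype, so no global precision is assumed.
    """
    offset = 0
    with torch.no_grad():
        for tnsr in tensors:
            size = tnsr.numel()
            chunk = array[offset : offset + size]
            values = torch.as_tensor(chunk, device=tnsr.device, dtype=tnsr.dtype)
            tnsr.copy_(values.reshape(tnsr.shape))
            offset += size


# Usage: float64 NumPy data written into float32 tensors.
targets = (torch.zeros(2, dtype=torch.float32), torch.zeros(1, 2, dtype=torch.float32))
set_tensors_from_flat_array(targets, np.arange(4, dtype=np.float64))
```

The target tensors keep their original dtype, so a float32 setup stays float32 even though the optimizer hands back float64 NumPy arrays. Whether this works unchanged on MPS has not been verified here.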

esantorella commented 1 month ago

Regarding the first issue, with the bounds in test functions, why use torch.get_default_dtype? Could we, say, pass a dtype argument to the function that defaults to torch.float64? The proposal to use torch.get_default_dtype will change the behavior and could potentially break some existing code, since torch.get_default_dtype is torch.float32 when not otherwise set.
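A sketch of the alternative suggested here, assuming a toy class (ToyTestProblem and its bounds are hypothetical, not BoTorch's BaseTestProblem): the constructor takes a dtype argument that defaults to torch.float64, so existing callers see no change while float32 users can opt in.

```python
import torch
from torch import nn


class ToyTestProblem(nn.Module):
    """Hypothetical stand-in for a test problem with a bounds buffer."""

    _bounds = [(-5, 5), (-5, 5)]

    def __init__(self, dtype: torch.dtype = torch.float64) -> None:
        super().__init__()
        # The buffer uses the requested dtype instead of a hard-coded one.
        self.register_buffer(
            "bounds", torch.tensor(self._bounds, dtype=dtype).transpose(-1, -2)
        )


default_problem = ToyTestProblem()  # keeps double precision, as before
single_problem = ToyTestProblem(dtype=torch.float32)  # opt in to float32
print(default_problem.bounds.dtype, single_problem.bounds.dtype)
```

Compared with torch.get_default_dtype(), this keeps the default behavior backward compatible at the cost of one more argument to thread through subclasses.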

esantorella commented 1 month ago

Regarding set_tensors_from_ndarray_1d, that cast to double precision was added along with this test. Basically, when optimization failed, NaNs would be returned and would trigger some sort of dtype-related error. So I'd make sure that test passes when making changes to that function.

I don't actually ever see the as_tensor arg of NdarrayOptimizationClosure be used; seems this is always the default. Given that I would probably recommend just getting rid of that additional complexity

This makes sense to me.

AVHopp commented 1 month ago

@esantorella Regarding the first point: I'll try to design a solution for this in the currently existing PR #2597 and I think it makes sense to continue the discussion there. For the second one, I'll also see if I can come up with a solution, either in the same PR or in another one (which I guess would be more favorable since this is a completely different part of the code?)

Balandat commented 1 month ago

For the second one, I'll also see if I can come up with a solution, either in the same PR or in another one (which I guess would be more favorable since this is a completely different part of the code?)

That would be great and yes this would best happen in a separate PR.