facebookresearch / alebo

Re-Examining Linear Embeddings for High-dimensional Bayesian Optimization

CUDA Support? #2

Open · LeoIV opened this issue 3 years ago

LeoIV commented 3 years ago

Hi, does ALEBO support CUDA? I initialized an AleboStrategy with device='cuda', but generation eventually fails with the error below. Is the problem my objective function, or is CUDA not supported in general?

[ERROR 02-26 14:19:00] ax.service.managed_loop: Tensor for argument #3 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for addmm)
Traceback (most recent call last):
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/utils/common/executils.py", line 98, in actual_wrapper
    return func(*args, **kwargs)
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/service/managed_loop.py", line 153, in run_trial
    generator_run=self.generation_strategy.gen(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/modelbridge/generation_strategy.py", line 406, in gen
    return self._gen_multiple(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/modelbridge/generation_strategy.py", line 511, in _gen_multiple
    generator_run = model.gen(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/modelbridge/base.py", line 628, in gen
    observation_features, weights, best_obsf, gen_metadata = self._gen(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/modelbridge/array.py", line 250, in _gen
    X, w, gen_metadata, candidate_metadata = self._model_gen(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/modelbridge/torch.py", line 219, in _model_gen
    X, w, gen_metadata, candidate_metadata = self.model.gen(
  File "/home/leonard/miniconda3/envs/proj/lib/python3.9/site-packages/ax/models/torch/alebo.py", line 668, in gen
    Xopt = (self.Binv @ Xd_opt.t()).t()
RuntimeError: Tensor for argument #3 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for addmm)
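The failing line is `(self.Binv @ Xd_opt.t()).t()` in `alebo.py`, so this looks like a device mismatch inside the model rather than a problem with the objective function. Going by the message, the second operand (`mat2`, which appears to be `Xd_opt.t()`) is still on the CPU while `Binv` lives on the GPU. The sketch below reproduces that expression and shows one possible local workaround: moving `Xd_opt` onto `Binv`'s device and dtype before the product. It is a guess, not a confirmed Ax fix. The helper name `project_to_ambient` is hypothetical, and the shapes (`B` as a d x D projection, `Binv` as its D x d pseudo-inverse) are assumptions inferred from the variable names. Whether a single `.to(...)` is enough also depends on every other tensor in the model staying on the same device, which this sketch does not check.

```python
import torch


def project_to_ambient(Binv: torch.Tensor, Xd_opt: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring `(self.Binv @ Xd_opt.t()).t()` from alebo.py.

    Assumes Binv has shape (D, d) and Xd_opt has shape (n, d); returns (n, D).
    Xd_opt is first moved to Binv's device and dtype, so the matmul never mixes
    CPU and CUDA tensors.
    """
    # Tensor.to(other) matches other's device and dtype (no copy if they already match).
    Xd_opt = Xd_opt.to(Binv)
    return (Binv @ Xd_opt.t()).t()


torch.manual_seed(0)
D, d, n = 10, 2, 3  # ambient dimension, embedding dimension, number of candidates
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed setup: B is a random (d, D) projection and Binv its pseudo-inverse,
# both kept on the requested device, as a model created with device='cuda' would.
B = torch.randn(d, D, dtype=torch.double, device=device)
Binv = torch.linalg.pinv(B)

# Candidates from the acquisition optimizer, left on the CPU
# (the situation the error message appears to describe).
Xd_opt = torch.rand(n, d, dtype=torch.double)

if device.type == "cuda":
    try:
        Binv @ Xd_opt.t()  # mixes a CUDA tensor and a CPU tensor
    except RuntimeError as err:
        print("unpatched product fails:", err)

X = project_to_ambient(Binv, Xd_opt)
print(X.shape, X.device)
```

On a CPU-only machine both tensors already share a device and the helper reduces to the original expression; on a GPU machine the unpatched product raises a `RuntimeError` like the one in the traceback, while the helper returns an (n, D) tensor on the GPU.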