automl / ifBO

In-context Bayesian Optimization

The example notebook does not work with GPU #2

Closed HideakiImamura closed 3 months ago

HideakiImamura commented 3 months ago

Expected behavior

The example notebook here should be executable in an environment where a GPU is available. I would like to send PRs to fix these issues.

Environment

```bash ll-workload-v1-zjnmw in ifBO/examples on 🌵 main [!?] > pip list Package Version Editable project location ------------------------------------- --------------------- ------------------------------ absl-py 2.1.0 accelerate 0.30.1 aiofiles 23.2.1 aiosignal 1.3.1 alembic 1.13.1 altair 5.3.0 annotated-types 0.6.0 ansicolors 1.1.8 anyio 4.3.0 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 astunparse 1.6.3 async-lru 2.0.4 attrs 23.2.0 audioread 3.0.1 awscli 1.32.103 Babel 2.15.0 backports.tarfile 1.1.1 beautifulsoup4 4.12.3 bitsandbytes 0.43.1 black 24.4.2 bleach 6.1.0 blosc2 2.6.2 boto3 1.34.103 botocore 1.34.103 bottle 0.12.25 cachetools 5.3.3 certifi 2024.2.2 cffi 1.16.0 chainer 7.8.1 charset-normalizer 3.3.2 click 8.0.2 cloudpickle 3.0.0 cmaes 0.10.0 cmake 3.29.3 colorama 0.4.6 colorlog 4.8.0 comm 0.2.2 contourpy 1.2.1 copt 0.9.1 cryptography 42.0.7 cupy-cuda12x 13.1.0 cycler 0.12.1 dacite 1.8.1 debugpy 1.8.1 decorator 5.1.1 defusedxml 0.7.1 deprecation 2.1.0 diskcache 5.6.3 dnspython 2.6.1 docker 7.0.0 docutils 0.16 einops 0.8.0 email_validator 2.1.1 empyrical 0.5.5 entrypoints 0.4 executing 2.0.1 fastapi 0.111.0 fastapi-cli 0.0.3 fastjsonschema 2.19.1 fastrlock 0.8.2 ffmpy 0.3.2 filelock 3.14.0 flake8 7.0.0 flake8-bugbear 24.4.26 flatbuffers 24.3.25 fonttools 4.51.0 fqdn 1.5.1 frozenlist 1.4.1 fsspec 2024.3.1 ftfy 6.2.0 gast 0.5.4 gitdb 4.0.11 GitPython 3.1.43 google-auth 2.29.0 google-pasta 0.2.0 gradio 4.31.0 gradio_client 0.16.2 greenlet 3.0.3 grpcio 1.63.0 gunicorn 22.0.0 h11 0.14.0 h5py 3.11.0 httpcore 1.0.5 httplib2 0.22.0 httptools 0.6.1 httpx 0.27.0 huggingface-hub 0.23.0 idna 3.7 ifBO 0.3.0 PATH/TO/ifBO importlib_metadata 7.1.0 importlib_resources 6.4.0 iniconfig 2.0.0 inquirerpy 0.3.4 interegular 0.3.3 ipykernel 6.29.4 ipynb-py-convert 0.4.6 ipython 8.24.0 ipywidgets 8.1.2 isoduration 20.11.0 isort 5.13.2 jaraco.classes 3.4.0 jaraco.context 5.3.0 jaraco.functools 4.0.1 jax 0.4.28 jaxlib 0.4.28+cuda12.cudnn89 jedi 0.19.1 jeepney 0.8.0 Jinja2 3.1.4 jmespath 1.0.1 joblib 1.4.2 json5 0.9.25 jsonpointer 2.4 jsonschema 4.22.0 jsonschema-specifications 2023.12.1 jupyter_client 8.6.1 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter-lsp 2.2.5 jupyter_server 2.14.0 jupyter_server_terminals 0.5.3 jupyterlab 4.1.8 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.1 jupyterlab_widgets 3.0.10 kaleido 0.2.1 keras 3.3.3 keyring 25.2.0 keyrings.google-artifactregistry-auth 1.1.2 kiwisolver 1.4.5 kubernetes-models-pfn 0.10.0 lark 1.1.9 lazy_loader 0.4 libclang 18.1.1 librosa 0.10.2 llvmlite 0.42.0 logzero 1.7.0 lxml 5.2.2 Mako 1.3.3 Markdown 3.6 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib 3.8.4 matplotlib-inline 0.1.7 mccabe 0.7.0 mdurl 0.1.2 minai 0.51.1 mistune 3.0.2 ml-dtypes 0.3.2 more-itertools 10.2.0 mpmath 1.3.0 msgpack 1.0.8 mypy 1.10.0 mypy-extensions 1.0.0 namex 0.0.8 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 ndindex 1.8 nest-asyncio 1.6.0 networkx 3.3 nglview 3.1.2 ninja 1.11.1.1 notebook 7.1.3 notebook_shim 0.2.4 numba 0.59.1 numexpr 2.10.0 numpy 1.26.4 openai-whisper 20231117 opt-einsum 3.3.0 optree 0.11.0 optuna 3.6.1 optuna-dashboard 0.15.1 optuna-integration 3.6.0 oras 0.1.26 orjson 3.10.3 outlines 0.0.34 overrides 7.7.0 packaging 24.0 pandas 2.2.2 pandas-datareader 0.10.0 pandocfilters 1.5.1 papermill 2.6.0 parso 0.8.4 pathspec 0.12.1 pexpect 4.9.0 pfio 2.8.0 pfzy 0.3.4 pillow 10.3.0 pip 23.3.2 platformdirs 4.2.1 plotly 5.22.0 pluggy 1.5.0 pooch 1.8.1 prometheus_client 0.20.0 prompt-toolkit 3.0.43 protobuf 4.25.3 psutil 5.9.8 
ptyprocess 0.7.0 pure-eval 0.2.2 py-cpuinfo 9.0.0 pyasn1 0.6.0 pyasn1_modules 0.4.0 pycodestyle 2.11.1 pycparser 2.22 pydantic 2.7.1 pydantic_core 2.18.2 pydub 0.25.1 pyflakes 3.2.0 Pygments 2.18.0 PyMySQL 1.1.0 pynvml 11.5.0 pyparsing 3.1.2 pysen 0.11.0 pytest 8.2.0 python-dateutil 2.9.0.post0 python-dotenv 1.0.1 python-json-logger 2.0.7 python-multipart 0.0.9 pytorch-pfn-extras 0.7.6 pytz 2024.1 PyYAML 6.0.1 pyzmq 26.0.3 ray 2.21.0 referencing 0.35.1 regex 2024.5.10 requests 2.31.0 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rich 13.7.1 rpds-py 0.18.1 rsa 4.7.2 ruff 0.4.4 s3transfer 0.10.1 safetensors 0.4.3 scikit-learn 1.4.2 scipy 1.14.0 seaborn 0.13.2 SecretStorage 3.3.3 semantic-version 2.10.0 Send2Trash 1.8.3 sentencepiece 0.2.0 setuptools 69.5.1 shellingham 1.5.4 six 1.16.0 smmap 5.0.1 sniffio 1.3.1 soundfile 0.12.1 soupsieve 2.5 sox 1.5.0 soxr 0.3.7 SQLAlchemy 2.0.30 stack-data 0.6.3 starlette 0.37.2 sympy 1.12 tables 3.9.2 tenacity 8.3.0 tensorboard 2.16.2 tensorboard-data-server 0.7.2 tensorflow 2.16.1 tensorflow-io 0.37.0 tensorflow-io-gcs-filesystem 0.37.0 termcolor 2.4.0 terminado 0.18.1 threadpoolctl 3.5.0 tiktoken 0.6.0 tinycss2 1.3.0 tokenizers 0.19.1 tomlkit 0.12.0 toolz 0.12.1 torch 2.1.2+cu121 torchaudio 2.1.2+cu121 torchvision 0.16.2+cu121 tornado 6.4 tqdm 4.66.4 traitlets 5.14.3 transformers 4.40.2 triton 2.1.0 typeguard 4.2.1 typer 0.12.3 types-python-dateutil 2.9.0.20240316 typing_extensions 4.11.0 tzdata 2024.1 ujson 5.9.0 unidiff 0.7.5 uri-template 1.3.0 urllib3 2.2.1 uvicorn 0.29.0 uvloop 0.19.0 vllm 0.4.0.post1 watchfiles 0.21.0 wcwidth 0.2.13 webcolors 1.13 webencodings 0.5.1 websocket-client 1.8.0 websockets 11.0.3 Werkzeug 3.0.3 wheel 0.43.0 widgetsnbextension 4.0.10 wrapt 1.16.0 xarray 2024.5.0 xformers 0.0.23.post1 zipp 3.18.1 ```

Error messages, stack traces, or logs

There are several reasons why the notebook cannot be executed. I break them down below.

The call `context, query = ifbo.utils.detokenize(batch, context_size=single_eval_pos)` fails as follows.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[8], line 1
----> 1 context, query = ifbo.utils.detokenize(batch, context_size=single_eval_pos)

AttributeError: module 'ifbo.priors.utils' has no attribute 'detokenize'

This is because of an unintended module structure: the name `ifbo.utils` currently resolves to `ifbo.priors.utils`. We should fix this line so that it does not expose `priors.utils` directly: https://github.com/automl/ifBO/blob/main/ifbo/__init__.py#L9
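Until the module structure is fixed, a notebook-side workaround seems to be importing the submodule by its dotted path so that the shadowed package attribute is bypassed (a rough sketch, assuming `detokenize` is actually defined in `ifbo/utils.py`; `batch` and `single_eval_pos` come from the earlier notebook cells):

```python
import importlib

# Workaround sketch: the package attribute `ifbo.utils` currently resolves to
# `ifbo.priors.utils`, so load the submodule by its dotted path instead.
# Assumes detokenize() is defined in ifbo/utils.py.
ifbo_utils = importlib.import_module("ifbo.utils")
context, query = ifbo_utils.detokenize(batch, context_size=single_eval_pos)
```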

In addition, if we fix the above line, the following error occurs in the call of `curve.hyperparameters[:3].numpy().tolist()`.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 2
      1 for curve in context:
----> 2     plt.plot(curve.t, curve.y, color=curve.hyperparameters[:3].numpy().tolist() + [0.5])
      3 for curve in query:
      4     plt.scatter(curve.t, curve.y, color=curve.hyperparameters[:3].numpy().tolist() + [0.5], s=7, marker="*")

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

This is because the `Tensor` objects in `context` and `query` are supposed to be created on the CPU, but they are actually created on the GPU when one is available. We should either fix how those tensors are created or move them onto the CPU before calling `.numpy()`.
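For reference, the plotting cell with that workaround applied looks roughly like this (only the hyperparameter tensors seem to live on the GPU; `context` and `query` come from the previous cell):

```python
import matplotlib.pyplot as plt

# Workaround sketch: copy the hyperparameter tensor to host memory before
# turning it into an RGBA color; `context` and `query` come from the notebook.
for curve in context:
    plt.plot(curve.t, curve.y,
             color=curve.hyperparameters[:3].cpu().numpy().tolist() + [0.5])
for curve in query:
    plt.scatter(curve.t, curve.y,
                color=curve.hyperparameters[:3].cpu().numpy().tolist() + [0.5],
                s=7, marker="*")
```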

In addition, if we work around the above by writing `curve.hyperparameters[:3].cpu().numpy().tolist()`, the following error occurs in the call of `predictions = model.predict(context=context, query=query)`.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 8
      5     for hp in curve.hyperparameters:
      6         assert hp.get_device() == 0
----> 8 predictions = model.predict(context=context, query=query)
      9 predictions[0].ucb().shape, predictions[0].quantile(0.5).shape

File /usr/local/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /mnt/nfs-mnj-home-43/mamu/ifBO/ifbo/surrogate.py:91, in FTPFN.predict(self, context, query)
     82 @torch.no_grad()
     83 def predict(
     84     self, context: List[Curve], query: List[Curve]
     85 ) -> List[PredictionResult]:
     86     """Obtain the logits for the given context and query curves.
     87 
     88     Function to perform Bayesian inference using FT-PFN that uses the logits obtained to 
     89     compute various measures like likelihood, UCB, EI, PI, and quantile.
     90     """
---> 91     x_train, y_train, x_test = tokenize(context, query)
     92     logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
     93     results = torch.split(logits, [len(curve.t) for curve in query], dim=0)

File /mnt/nfs-mnj-home-43/mamu/ifBO/ifbo/utils.py:463, in tokenize(context, query)
    460     num_points = curve.t.size(0)
    461     for i in range(num_points):
    462         context_tokens.append(
--> 463             torch.cat(
    464                 (torch.tensor([curve_id, curve.t[i].item()]), curve.hyperparameters)
    465             )
    466         )
    467         context_y_values.append(curve.y[i])
    469 for curve in query:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)

This is because some tensors are created on the CPU while others are created on the GPU. This behavior should be fixed.
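To make this concrete, here is a minimal standalone sketch of the mismatch and one possible device-consistent fix for the failing `torch.cat` (hypothetical names, not the actual library code; alternatively, the hyperparameter tensors could simply be created on the CPU in the first place):

```python
import torch

# Minimal reproduction of the device mismatch hit inside tokenize(),
# using hypothetical standalone names.
hyperparameters = torch.rand(3, device="cuda")  # on the GPU, like curve.hyperparameters
curve_id, t_i = 0.0, 0.25

# Mirrors the failing line and raises
# "Expected all tensors to be on the same device":
#   torch.cat((torch.tensor([curve_id, t_i]), hyperparameters))

# One possible fix: create the id/step tensor on the same device as the
# hyperparameters before concatenating.
token = torch.cat(
    (torch.tensor([curve_id, t_i], device=hyperparameters.device), hyperparameters)
)
print(token.device)  # cuda:0
```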

Steps to reproduce

Run the example notebook in an environment where a GPU is available.