JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: Installation with GPU support #102

Closed seanjkanderson closed 1 year ago

seanjkanderson commented 2 years ago

Bug Report

Moving this from Discussions as it wasn't getting much traffic there. GPJax version: 0.4.10

Current behavior: I can install GPJax with CPU-only support, but not with GPU support. I have tried to install gpjax with CUDA support both by following the current gpjax ReadTheDocs instructions and by first installing JAX according to the latest install-with-CUDA directions; neither approach gives me a working GPU install.

I've tried to find a CUDA-compatible build matching jax==0.3.5, but haven't found one that is also supported by my machine.

I've also tried installing it in a Google Colab notebook with !pip install gpjax, and this disables GPU support because of the package versions pinned in setup.py.

I'm thinking the setup.py file needs to be updated to a newer jax version, with corresponding updates for the changes in jax since 0.3.5 - or I'm misunderstanding how to configure GPU support with GPJax.

Expected behavior: I would expect GPJax to support running on GPU.

Steps to reproduce:

Ubuntu 20.04.3, tested with Python 3.8-3.10. Run the regression example, for instance:

from pprint import PrettyPrinter

import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.example_libraries import optimizers
from jax import jit
import optax as ox

pp = PrettyPrinter(indent=4)
key = jr.PRNGKey(123)

def pred_dist(xtest, y, training, final_params):
    latent_distribution = posterior(training, final_params)(xtest)
    predictive_distribution = likelihood(latent_distribution, final_params)
    return jnp.sum(predictive_distribution.log_prob(y))

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    n = 100
    noise = 0.3

    x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).sort().reshape(-1, 1)
    f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
    signal = f(x)
    y = signal + jr.normal(key, shape=signal.shape) * noise

    D = gpx.Dataset(X=x, y=y)

    xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
    ytest = f(xtest)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(xtest, ytest, label="Latent function")
    ax.plot(x, y, "o", label="Observations")
    ax.legend(loc="best")

    kernel = gpx.RBF()
    prior = gpx.Prior(kernel=kernel)
    likelihood = gpx.Gaussian(num_datapoints=D.n)
    posterior = prior * likelihood
    params, trainable, constrainer, unconstrainer = gpx.initialise(posterior)
    pp.pprint(params)
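    # NOTE: the gpx.transform call on the next line is what raises the AttributeError in the traceback below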
    params = gpx.transform(params, unconstrainer)
    mll = jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))
    mll(params)

    opt = ox.adam(learning_rate=0.01)
    final_params = gpx.fit(
        mll,
        params,
        trainable,
        opt,
        n_iters=500,
    )

    final_params = gpx.transform(final_params, constrainer)
    pp.pprint(final_params)

    latent_dist = posterior(D, final_params)(xtest)
    predictive_dist = likelihood(latent_dist, final_params)

    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()
    # test gradient behavior with respect to posterior
    gtest = jax.grad(pred_dist)
    grad = gtest(xtest, y, D, final_params)

Running this returns the following warning and error:

/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/chex-0.1.3-py3.10.egg/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))
{   'kernel': {   'lengthscale': DeviceArray([1.], dtype=float64),
                  'variance': DeviceArray([1.], dtype=float64)},
    'likelihood': {'obs_noise': DeviceArray([1.], dtype=float64)},
    'mean_function': {}}
Traceback (most recent call last):
  File "/home/paperspace/pycharm_projects/exp_design/gpjax_demo.py", line 47, in <module>
    params = gpx.transform(params, unconstrainer)
  File "/home/paperspace/GPJax/gpjax/parameters.py", line 147, in transform
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 205, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 205, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/paperspace/GPJax/gpjax/parameters.py", line 147, in <lambda>
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/bijectors/lambda_bijector.py", line 112, in inverse
    return self._inverse(y)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/utils/transformations.py", line 119, in wrapped
    out = _interpret_inverse(jaxpr, consts, *args)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/utils/transformations.py", line 250, in _interpret_inverse
    write(jax.core.unitvar, jax.core.unit)
AttributeError: module 'jax.core' has no attribute 'unitvar'

I can try to provide something more minimal, but I'm guessing I'm just missing something more fundamental on install.

In Colab, run !pip install gpjax, then reload as per the prompt and run import gpjax as gpx. At that point it should indicate that there is no GPU/TPU found.
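For reference, a quick way to check programmatically which backend JAX has picked up (plain JAX calls, nothing GPJax-specific):

import jax
print(jax.default_backend())  # "cpu" here means the CUDA-enabled jaxlib has been replaced
print(jax.devices())          # lists the devices JAX can actually see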

Thanks!

thomaspinder commented 2 years ago

Hi @seanjkanderson. Thanks for raising this - it's a particularly tricky issue that I've been trying to get to the bottom of for a while now. We've just released v0.4.12, which resolves both of the issues above. To be explicit: you can now run GPJax on Colab GPUs without having to restart the kernel after installation. I've put together the notebook below to get you started. Please do reach out with any further issues - I'd be interested to hear about your experiences, be they positive or negative!

https://colab.research.google.com/drive/1FWyhE8XLJbWaneDRUNzuZQvNueO-kLdc?usp=sharing

If you are happy that this resolves your issue, then please feel free to close this issue.

seanjkanderson commented 2 years ago

Thanks for the help @thomaspinder. That example works for me in Colab and on a VM running Ubuntu (and I was able to get it working in my project as well).

It seems like you intentionally didn't use gpjax.fit(), as tree_map and InferenceState don't seem compatible at the moment? Is that right, or is there a new expected behavior/use of .fit()? Just curious, as I can use the training approach described in the notebook.

Cheers!

thomaspinder commented 2 years ago

No problem @seanjkanderson!

fit still works - you just have to be sure to pass it params rather than the ParameterState object. This is something we're looking to fix in the next few days though. To run fit, you'd simply run:

inference_state = gpx.fit(mll, params, trainables, opt)
final_params_fit, history = inference_state.unpack()
final_params_fit = gpx.transform(final_params_fit, constrainer)

I have also added a cell containing this code to the notebook linked above - the only reason I left it out originally was to stay true to the README code.

seanjkanderson commented 2 years ago

Got it, my mistake was passing the ParameterState object to transform(). Thanks @thomaspinder!
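
For anyone else who trips over this, a minimal sketch of the distinction - the unpack() layout below is assumed from the linked notebook and may differ between GPJax versions:

parameter_state = gpx.initialise(posterior)  # returns a ParameterState, not a plain params dict
params, trainables, constrainer, unconstrainer = parameter_state.unpack()  # assumed layout
params = gpx.transform(params, unconstrainer)  # pass the params dict...
# gpx.transform(parameter_state, unconstrainer)  # ...not the ParameterState (my earlier mistake)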

seanjkanderson commented 1 year ago

Just got back to using GPJax on Colab/a machine with GPU support. It seems that installing GPJax once again uninstalls the CUDA-enabled jaxlib. What's the workaround here @thomaspinder?

Thanks!

Example section of trace from pip install gpjax

Installing collected packages: typeguard, ml-collections, deprecation, jaxtyping, jaxlib, jax, chex, optax, distrax, jaxutils, jaxlinop, jaxkern, gpjax
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    Uninstalling typeguard-2.7.1:
      Successfully uninstalled typeguard-2.7.1
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.25+cuda11.cudnn805
    Uninstalling jaxlib-0.3.25+cuda11.cudnn805:
      Successfully uninstalled jaxlib-0.3.25+cuda11.cudnn805
  Attempting uninstall: jax
    Found existing installation: jax 0.3.25
    Uninstalling jax-0.3.25:
      Successfully uninstalled jax-0.3.25
Successfully installed chex-0.1.5 deprecation-2.1.0 distrax-0.1.2 gpjax-0.5.9 jax-0.4.2 jaxkern-0.0.5 jaxlib-0.4.2 jaxlinop-0.0.3 jaxtyping-0.2.11 jaxutils-0.0.8 ml-collections-0.1.0 optax-0.1.4 typeguard-2.13.3

seanjkanderson commented 1 year ago

The workaround for using a colab notebook seems to be to use:

!pip install gpjax
!pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

where the second command reinstalls the CUDA build of jaxlib. I would be interested to know if there's a single convenient command for installing gpjax with CUDA support. I believe the installation guide says to set CUDA_VERSION=XX before installing gpjax, but this didn't seem to work for me in Colab.
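
After running the two commands, a quick sanity check (assuming each package exposes __version__) that the versions still line up and that the CUDA backend is active:

import gpjax, jax, jaxlib
print(gpjax.__version__, jax.__version__, jaxlib.__version__)
print(jax.default_backend())  # should report "gpu" if the CUDA jaxlib from the second command survived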