JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: The latest version of jax causes a TypeError when making predictions. #463

Closed: mikkelbue closed this issue 3 weeks ago

mikkelbue commented 1 month ago

Bug Report

GPJax version:

0.8.2

Current behavior:

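Making predictions with gpjax.gps.Prior and gpjax.gps.ConjugatePosterior (and possibly other models) raises a TypeError. This is caused by a change to jax.Array in jax 0.4.31. As the traceback below shows, cola calls array.device(), but array.device now evaluates to a Device object, and that object is not callable.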

Expected behavior:

The predict methods of these models should return a predictive distribution without raising an error.

Steps to reproduce:

Install gpjax==0.8.2 from PyPI with jax 0.4.31 or later, and run, e.g., the GPJax [regression example](https://github.com/JaxGaussianProcesses/GPJax/blob/main/docs/examples/regression.py).

Here is the traceback from regression.py:

Traceback (most recent call last):
  File "/home/mikkel/git/GPJax/docs/examples/regression.py", line 122, in <module>
    prior_dist = prior.predict(xtest)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/gps.py", line 278, in predict
    Kxx = self.kernel.gram(x)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/kernels/base.py", line 64, in gram
    return self.compute_engine.gram(self, x)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/kernels/computations/base.py", line 57, in gram
    return PSD(Dense(Kxx))
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 22, in __new__
    obj.device = find_device([args, kwargs])
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 272, in find_device
    device = find_device(ob)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 272, in find_device
    device = find_device(ob)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 267, in find_device
    return xnp.get_device(obj)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/backends/jax_fns.py", line 159, in get_device
    return array.device()
TypeError: 'Device' object is not callable

Other information:

I believe the last entry in the jax changelog for version 0.4.31 is related to this issue.

thomaspinder commented 3 weeks ago

Thanks for flagging this, @mikkelbue. Hopefully this is resolved in v0.9.0.