Bug Report

GPJax version: 0.8.2

Current behavior:
Making predictions with `gpjax.gps.Prior` and `gpjax.gps.ConjugatePosterior` (and possibly others) raises a `TypeError`. This happens because of a change to `jax.Array` in version 0.4.31 of `jax`. As the traceback below shows, CoLA's JAX backend calls `array.device()`, but `array.device` now evaluates to a `Device` object, which is not callable.

Expected behavior:
The `predict` methods of these distributions should not raise an error.

Steps to reproduce:
Install `gpjax==0.8.2` from PyPI and run, for example, the `gpjax` [regression example](https://github.com/JaxGaussianProcesses/GPJax/blob/main/docs/examples/regression.py).

Here is the traceback from `regression.py`:

```
Traceback (most recent call last):
  File "/home/mikkel/git/GPJax/docs/examples/regression.py", line 122, in <module>
    prior_dist = prior.predict(xtest)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/gps.py", line 278, in predict
    Kxx = self.kernel.gram(x)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/kernels/base.py", line 64, in gram
    return self.compute_engine.gram(self, x)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/gpjax/kernels/computations/base.py", line 57, in gram
    return PSD(Dense(Kxx))
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 22, in __new__
    obj.device = find_device([args, kwargs])
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 272, in find_device
    device = find_device(ob)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 272, in find_device
    device = find_device(ob)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/ops/operator_base.py", line 267, in find_device
    return xnp.get_device(obj)
  File "/home/mikkel/venv/test/lib/python3.10/site-packages/cola/backends/jax_fns.py", line 159, in get_device
    return array.device()
TypeError: 'Device' object is not callable
```

Other information:
I believe the last entry in the `jax` changelog for version 0.4.31 is related to this issue.
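For illustration, here is a minimal sketch of a version-tolerant device lookup. It assumes that the only relevant change is the one the traceback suggests: on older `jax`, `array.device` is a method returning the `Device`, while on 0.4.31 `array.device` already holds the (non-callable) `Device`. The name `get_device` mirrors the CoLA function in the traceback, but this is a hypothetical patch, not CoLA's actual code or fix.

```python
from typing import Any


def get_device(array: Any) -> Any:
    """Return the device an array lives on under either jax convention.

    Hypothetical shim, not CoLA's implementation. Assumes that:
    - on older jax, ``array.device`` is a method returning the Device;
    - on jax 0.4.31, ``array.device`` already is the Device object,
      which is not callable (hence the TypeError in the traceback).
    """
    device = array.device
    # Older jax: a bound method, so call it to obtain the Device.
    if callable(device):
        return device()
    # jax 0.4.31: the attribute already is the Device; return it unchanged.
    return device
```

Checking `callable` instead of the `jax` version keeps the shim independent of version parsing. For single-device arrays, `next(iter(array.devices()))` should be another option, since `jax.Array.devices()` returns a set of devices.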