JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

Conflict between plum-dispatch and cola-plum-dispatch #441

Open vabor112 opened 8 months ago

vabor112 commented 8 months ago

Bug Report

GPJax version: 0.8.0

Current behavior:

I am trying to update the existing integration of GeometricKernels with GPJax so that it works with newer versions of GPJax. It works for GPJax 0.6.9, but with the current GPJax 0.8.0 I run into two problems.

The first one is exactly #397. Although quite annoying, it can be worked around by downgrading TensorFlow to version 2.13.

The second one is illustrated in the Related code section below. I believe it concerns plum-dispatch, which we use extensively in GeometricKernels to support multiple backends. GPJax uses cola, which in turn relies on cola-plum-dispatch, a fork of plum. This unmaintained fork uses the same namespace, plum (which seems like a terrible sin), so it gets overridden by the actual plum that GeometricKernels uses, causing the error below. I believe this is similar to this issue.

Expected behavior:

I am not sure how to fix this, but it seems important to fix, since otherwise GPJax is incompatible with any other library that relies on plum-dispatch, which is quite popular.

Steps to reproduce:

See below.

Related code:

It is enough to run this snippet:

# Import a backend; we use JAX in this example.
import jax.numpy as jnp
import jax
import gpjax as gpx

# Import the geometric_kernels backend.
import geometric_kernels
import geometric_kernels.jax

which leads to

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 4
      2 import jax.numpy as jnp
      3 import jax
----> 4 import gpjax as gpx
      6 # Import the geometric_kernels backend.
      7 import geometric_kernels

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/__init__.py:15
      1 # Copyright 2022 The GPJax Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax import (
     16     base,
     17     decision_making,
     18     gps,
     19     integrators,
     20     kernels,
     21     likelihoods,
     22     mean_functions,
     23     objectives,
     24     variational_families,
     25 )
     26 from gpjax.base import (
     27     Module,
     28     param_field,
     29 )
     30 from gpjax.citation import cite

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/__init__.py:15
      1 # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax.decision_making.decision_maker import (
     16     AbstractDecisionMaker,
     17     UtilityDrivenDecisionMaker,
     18 )
     19 from gpjax.decision_making.posterior_handler import PosteriorHandler
     20 from gpjax.decision_making.search_space import (
     21     AbstractSearchSpace,
     22     ContinuousSearchSpace,
     23 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/decision_maker.py:32
     29 import jax.random as jr
     31 from gpjax.dataset import Dataset
---> 32 from gpjax.decision_making.posterior_handler import PosteriorHandler
     33 from gpjax.decision_making.search_space import AbstractSearchSpace
     34 from gpjax.decision_making.utility_functions import (
     35     AbstractUtilityFunctionBuilder,
     36     ThompsonSampling,
     37 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/posterior_handler.py:25
     23 import gpjax as gpx
     24 from gpjax.dataset import Dataset
---> 25 from gpjax.gps import (
     26     AbstractLikelihood,
     27     AbstractPosterior,
     28     AbstractPrior,
     29 )
     30 from gpjax.objectives import AbstractObjective
     31 from gpjax.typing import KeyArray

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/gps.py:26
     18 from typing import overload
     20 from beartype.typing import (
     21     Any,
     22     Callable,
     23     Optional,
     24     Union,
     25 )
---> 26 import cola
     27 from cola.linalg.decompositions.decompositions import Cholesky
     28 import jax.numpy as jnp

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/__init__.py:11
      9 __all__ = []
     10 # for loader, module_name, is_pkg in  pkgutil.walk_packages(__path__):
---> 11 import_from_all("fns", globals(), __all__, __name__)
     12 import_from_all("annotations", globals(), __all__, __name__)
     13 import_from_all("linalg", globals(), __all__, __name__)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/utils/__init__.py:36, in import_from_all(module_name, namespace, _all, _name)
     32 def import_from_all(module_name, namespace, _all, _name):
     33     """Import all functions from module.__all__ into the namespace and add to __all__.
     34     example usage: import_every("operators",globals(),__all__,__name__)
     35     """
---> 36     module = importlib.import_module('.' + module_name, package=_name)
     37     if not hasattr(module, "__all__"):
     38         logging.debug(f"empty {module_name}.__all__")

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/importlib/__init__.py:126, in import_module(name, package)
    124             break
    125         level += 1
--> 126 return _bootstrap._gcd_import(name[level:], package, level)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/fns.py:127
    122 @dispatch
    123 def transpose(A: Dense):
    124     return Dense(A.A.T)
--> 127 @dispatch(cond=lambda A: A.isa(cola.SelfAdjoint))
    128 def transpose(A: LinearOperator):
    129     # dangerous, TODO: fix when A is complex or unify transpose and adjoint
    130     return A
    133 @dispatch
    134 def transpose(A: Triangular):

TypeError: Dispatcher.__call__() got an unexpected keyword argument 'cond'
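
The final TypeError is what one would expect if the @dispatch(cond=...) decorator in cola/fns.py receives a dispatcher whose __call__ does not accept a cond keyword. The toy example below reproduces the same class of error with two hypothetical stand-in classes, ForkDispatcher and UpstreamDispatcher; they are not the real cola-plum-dispatch or plum classes, and their signatures are assumptions chosen only to illustrate the mechanism. The accepts_keyword helper (also hypothetical) uses inspect.signature to check whether a callable supports a given keyword.

```python
import inspect
from typing import Callable, Optional


class ForkDispatcher:
    """Hypothetical stand-in for a dispatcher that accepts `cond`."""

    def __call__(self, method: Optional[Callable] = None, *,
                 cond: Optional[Callable] = None):
        if method is None:
            # Used as @dispatch(cond=...): return a decorator.
            return lambda f: self(f, cond=cond)
        method.cond = cond  # Record the condition on the function.
        return method


class UpstreamDispatcher:
    """Hypothetical stand-in for a dispatcher without a `cond` keyword."""

    def __call__(self, method: Optional[Callable] = None, precedence: int = 0):
        if method is None:
            return lambda f: self(f, precedence=precedence)
        return method


def accepts_keyword(func: Callable, name: str) -> bool:
    """Return True if `func` can be called with the keyword argument `name`."""
    P = inspect.Parameter
    for p in inspect.signature(func).parameters.values():
        if p.kind is P.VAR_KEYWORD:
            return True  # **kwargs swallows any keyword.
        if p.name == name and p.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY):
            return True
    return False


def register_with_cond(dispatch) -> Callable:
    """Mimic the @dispatch(cond=...) pattern from cola/fns.py in the traceback."""
    @dispatch(cond=lambda A: True)
    def transpose(A):
        return A
    return transpose


if __name__ == "__main__":
    register_with_cond(ForkDispatcher())  # Works: `cond` is accepted.
    try:
        register_with_cond(UpstreamDispatcher())
    except TypeError as err:
        # Message ends with: got an unexpected keyword argument 'cond'
        print(err)
```

Against the real packages, a check such as accepts_keyword(plum.Dispatcher.__call__, "cond") after import plum would indicate which of the two plum implementations actually got imported.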
github-actions[bot] commented 3 months ago

This issue has been marked as stale because it has been open for 7 days with no activity.

aterenin commented 3 months ago

Spoke to @daniel-dodd: this issue is caused by cola, an upstream dependency that GPJax relies on. An issue should be filed there.

aterenin commented 3 months ago

CC: @vabor112

thomaspinder commented 3 months ago

Thanks for the update, @aterenin. I'd prefer to see this fixed upstream in Cola; otherwise we may need to fork the project and implement a workaround. Needless to say, this would be messy.

vabor112 commented 2 months ago

I filed an issue with the cola developers.