gchq / coreax

A library for coreset algorithms, written in Jax for fast execution and GPU support.
Apache License 2.0

Incompatible with Equinox v0.11.8 #817

Open rg936672 opened 1 month ago

rg936672 commented 1 month ago

What's the problem?

Equinox v0.11.8 causes various TypeErrors in the score matching code. See https://github.com/gchq/coreax/actions/runs/11437260050/job/31816415291?pr=816, or the traceback from a single test (copied from that Actions run) below.

How can we reproduce the issue?

1. Remove the `<0.11.8` restriction on the version of Equinox.
2. Install Coreax alongside Equinox v0.11.8.
3. Run the unit tests.
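Before running the tests, it can help to confirm which Equinox version is actually installed in the environment. The following is a minimal sketch using only the standard library; the helper names `release_tuple` and `outside_pin` are hypothetical and not part of Coreax or Equinox, and the version parsing is deliberately simplified (pre-release suffixes are not handled properly). It only reports whether the installed version falls outside the `<0.11.8` restriction; it makes no claim about which other versions are affected.

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(version_string: str) -> tuple[int, ...]:
    """Parse the leading numeric components, e.g. '0.11.8' -> (0, 11, 8)."""
    parts = []
    for piece in version_string.split("."):
        if not piece.isdigit():
            break  # simplified: stop at suffixes such as 'dev0' or 'rc1'
        parts.append(int(piece))
    return tuple(parts)


def outside_pin(version_string: str, bound: tuple[int, ...] = (0, 11, 8)) -> bool:
    """Return True if the version does not satisfy the '<0.11.8' restriction."""
    return release_tuple(version_string) >= bound


# Look up the installed Equinox version, if any.
try:
    installed = version("equinox")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("Equinox is not installed")
elif outside_pin(installed):
    print(f"Equinox {installed} is outside the <0.11.8 restriction")
else:
    print(f"Equinox {installed} satisfies the <0.11.8 restriction")
```

If the script reports a version outside the restriction, the unit tests are running against the configuration described in this issue.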

Python version

3.9, 3.10, 3.11, 3.12

Package version

0.2.1

Operating system

macOS 14.7; Microsoft Windows Server 2022; Ubuntu 22.04.5

Other packages

No response

Relevant log output

_______________ TestSteinThinning.test_reduce[without_jit-True] ________________

self = <tests.unit.test_solvers.TestSteinThinning object at 0x10b858d50>
jit_variant = <function jit_variant.<locals>.<lambda> at 0x37bfa2200>
reduce_problem = _ReduceProblem(dataset=Data(data=f32[128,10], weights=weak_i32[128]), solver=SteinThinning(
  coreset_size=12,
  kerne...=1.0),
  score_matching=None,
  unique=True,
  regularise=True,
  block_size=None,
  unroll=1
), expected_coreset=None)
use_cached_state = True

    @pytest.mark.parametrize("use_cached_state", (False, True))
    def test_reduce(
        self,
        jit_variant: Callable[[Callable], Callable],
        reduce_problem: _ReduceProblem,
        use_cached_state: bool,
    ) -> None:
        """
        Check 'reduce' raises no errors and is resultant 'solver_state' invariant.

        By resultant 'solver_state' invariant we mean the following procedure succeeds:
        1. Call 'reduce' with the default 'solver_state' to get the resultant state
        2. Call 'reduce' again, this time passing the 'solver_state' from the previous
            run, and keeping all other arguments the same.
        3. Check the two calls to 'refine' yield that same result.
        """
        dataset, solver, _ = reduce_problem
>       coreset, state = jit_variant(solver.reduce)(dataset)

tests/unit/test_solvers.py:162: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
coreax/solvers/coresubset.py:335: in reduce
    return self.refine(initial_coresubset, solver_state)
coreax/solvers/coresubset.py:353: in refine
    kernel = convert_stein_kernel(x, self.kernel, self.score_matching)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

x = Array([[0.90712607, 0.885797  , 0.10075748, ..., 0.9664997 , 0.00239444,
        0.41092277],
       [0.23769128, 0.13...2072847],
       [0.37630963, 0.03963459, 0.25985646, ..., 0.2631793 , 0.6154243 ,
        0.81453335]], dtype=float32)
kernel = PCIMQKernel(length_scale=1.0, output_scale=1.0), score_matching = None

    def convert_stein_kernel(
        x: Shaped[Array, " n d"],
        kernel: ScalarValuedKernel,
        score_matching: Union[ScoreMatching, None],
    ) -> SteinKernel:
        r"""
        Convert the kernel to a :class:`~coreax.kernels.SteinKernel`.

        :param x: The data used to call `score_matching.match(x)`
        :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
            kernel function
            :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`; if 'kernel'
            is a :class:`~coreax.kernels.SteinKernel` and :code:`score_matching is not
            data:`None`, a new instance of the kernel will be generated where the score
            function is given by :code:`score_matching.match(x)`
        :param score_matching: Specifies/overwrite the score function of the implied/passed
           :class:`~coreax.kernels.SteinKernel`; if :data:`None`, default to
           :class:`~coreax.score_matching.KernelDensityMatching` unless 'kernel' is a
           :class:`~coreax.kernels.SteinKernel`, in which case the kernel's existing score
           function is used.
        :return: The (potentially) converted/updated :class:`~coreax.kernels.SteinKernel`.
        """
        if isinstance(kernel, SteinKernel):
            if score_matching is not None:
                _kernel = eqx.tree_at(
                    lambda x: x.score_function, kernel, score_matching.match(x)
                )
            else:
                _kernel = kernel
        else:
            if score_matching is None:
                length_scale = getattr(kernel, "length_scale", 1.0)
>               score_matching = KernelDensityMatching(length_scale)
E               TypeError: __class__ assignment: 'KernelDensityMatching' object layout differs from 'KernelDensityMatching'

coreax/score_matching.py:612: TypeError
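
For context on the error message itself: CPython raises this `TypeError` when an object's `__class__` is reassigned to a class whose instance memory layout is incompatible. Both names in the message above are `KernelDensityMatching`, which is what one would see if two distinct class objects share a name. The sketch below reproduces the same kind of message in plain Python; it is only an illustration of the general mechanism, not a diagnosis of what Equinox v0.11.8 does internally, and the `make_class` helper and `Model` name are hypothetical.

```python
def make_class(slots):
    """Build a new class named 'Model' with the given __slots__.

    Each call returns a distinct class object, even though the name is shared.
    """
    return type("Model", (), {"__slots__": slots})


Old = make_class(("x",))
New = make_class(("x", "y"))  # same name, different instance layout

obj = Old()
message = None
try:
    # Reassigning __class__ requires compatible instance layouts.
    obj.__class__ = New
except TypeError as err:
    message = str(err)

print(message)
```

The failed assignment leaves `obj` as an instance of the original class; only the layout check is being demonstrated here.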

tp832944 commented 2 weeks ago

Might be caused by #582.