Remove the <0.11.8 restriction on the version of Equinox. Install Coreax alongside Equinox v0.11.8. Run the unit tests.
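The reproduction steps can be scripted roughly as follows. This is a hypothetical sketch. It assumes that it runs from the root of a Coreax checkout in which the `<0.11.8` Equinox pin has already been removed, and that pytest is already installed. The helper names (`parse_release`, `is_target_equinox`, `reproduce`) are illustrative and are not part of Coreax.

```python
import subprocess
import sys
from importlib import metadata

TARGET_EQUINOX = (0, 11, 8)  # the Equinox release named in this report


def parse_release(version: str) -> tuple:
    # Turn a plain "X.Y.Z" release string into integers, e.g. "0.11.8" -> (0, 11, 8).
    # Pre-release suffixes such as "rc1" are not handled by this sketch.
    return tuple(int(part) for part in version.split(".")[:3])


def is_target_equinox(version: str) -> bool:
    # True only for Equinox v0.11.8, the version this report reproduces with.
    return parse_release(version) == TARGET_EQUINOX


def reproduce() -> int:
    # Hypothetical helper: run from the root of a Coreax checkout whose
    # "<0.11.8" Equinox pin has already been removed.
    # 1. Install Coreax from the checkout alongside Equinox v0.11.8.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", ".", "equinox==0.11.8"],
        check=True,
    )
    # 2. Confirm which Equinox version actually ended up installed.
    installed = metadata.version("equinox")
    if not is_target_equinox(installed):
        print(f"expected Equinox 0.11.8, found {installed}")
        return 1
    # 3. Run the unit tests; the failing test in this report is in tests/unit.
    return subprocess.run([sys.executable, "-m", "pytest", "tests/unit"]).returncode
```

Calling `reproduce()` returns pytest's exit code, which would be non-zero while the failures described here are present.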
Python version
3.9, 3.10, 3.11, 3.12
Package version
0.2.1
Operating system
macOS 14.7; Microsoft Windows Server 2022; Ubuntu 22.04.5
Other packages
No response
Relevant log output
```
_______________ TestSteinThinning.test_reduce[without_jit-True] ________________

self = <tests.unit.test_solvers.TestSteinThinning object at 0x10b858d50>
jit_variant = <function jit_variant.<locals>.<lambda> at 0x37bfa2200>
reduce_problem = _ReduceProblem(dataset=Data(data=f32[128,10], weights=weak_i32[128]), solver=SteinThinning(
  coreset_size=12,
  kerne...=1.0),
  score_matching=None,
  unique=True,
  regularise=True,
  block_size=None,
  unroll=1
), expected_coreset=None)
use_cached_state = True

    @pytest.mark.parametrize("use_cached_state", (False, True))
    def test_reduce(
        self,
        jit_variant: Callable[[Callable], Callable],
        reduce_problem: _ReduceProblem,
        use_cached_state: bool,
    ) -> None:
        """
        Check 'reduce' raises no errors and is resultant 'solver_state' invariant.
        By resultant 'solver_state' invariant we mean the following procedure succeeds:
        1. Call 'reduce' with the default 'solver_state' to get the resultant state
        2. Call 'reduce' again, this time passing the 'solver_state' from the previous
            run, and keeping all other arguments the same.
        3. Check the two calls to 'refine' yield that same result.
        """
        dataset, solver, _ = reduce_problem
>       coreset, state = jit_variant(solver.reduce)(dataset)

tests/unit/test_solvers.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
coreax/solvers/coresubset.py:335: in reduce
    return self.refine(initial_coresubset, solver_state)
coreax/solvers/coresubset.py:353: in refine
    kernel = convert_stein_kernel(x, self.kernel, self.score_matching)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = Array([[0.90712607, 0.885797 , 0.10075748, ..., 0.9664997 , 0.00239444,
        0.41092277],
       [0.23769128, 0.13...2072847],
       [0.37630963, 0.03963459, 0.25985646, ..., 0.2631793 , 0.6154243 ,
        0.81453335]], dtype=float32)
kernel = PCIMQKernel(length_scale=1.0, output_scale=1.0), score_matching = None

    def convert_stein_kernel(
        x: Shaped[Array, " n d"],
        kernel: ScalarValuedKernel,
        score_matching: Union[ScoreMatching, None],
    ) -> SteinKernel:
        r"""
        Convert the kernel to a :class:`~coreax.kernels.SteinKernel`.
        :param x: The data used to call `score_matching.match(x)`
        :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
            kernel function
            :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`; if 'kernel'
            is a :class:`~coreax.kernels.SteinKernel` and :code:`score_matching is not
            data:`None`, a new instance of the kernel will be generated where the score
            function is given by :code:`score_matching.match(x)`
        :param score_matching: Specifies/overwrite the score function of the implied/passed
            :class:`~coreax.kernels.SteinKernel`; if :data:`None`, default to
            :class:`~coreax.score_matching.KernelDensityMatching` unless 'kernel' is a
            :class:`~coreax.kernels.SteinKernel`, in which case the kernel's existing score
            function is used.
        :return: The (potentially) converted/updated :class:`~coreax.kernels.SteinKernel`.
        """
        if isinstance(kernel, SteinKernel):
            if score_matching is not None:
                _kernel = eqx.tree_at(
                    lambda x: x.score_function, kernel, score_matching.match(x)
                )
            else:
                _kernel = kernel
        else:
            if score_matching is None:
                length_scale = getattr(kernel, "length_scale", 1.0)
>               score_matching = KernelDensityMatching(length_scale)
E               TypeError: __class__ assignment: 'KernelDensityMatching' object layout differs from 'KernelDensityMatching'

coreax/score_matching.py:612: TypeError
```
What's the problem?
Equinox v0.11.8 causes various TypeErrors in the score-matching code. See https://github.com/gchq/coreax/actions/runs/11437260050/job/31816415291?pr=816 for the full CI run, or the traceback from a single test (copied from that Actions run) in the log output section of this report.
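The message in the traceback is CPython's own `TypeError` for reassigning an object's `__class__` to a class whose instance layout is incompatible; it names the new class first and the old class second. Since both names here are `KernelDensityMatching`, the assignment appears to involve two distinct class objects that happen to share that name. The pure-Python sketch below reproduces the same message using two hypothetical `__slots__` classes. It only illustrates what the error means and makes no claim about what Equinox v0.11.8 does internally.

```python
def make_class(slot_name: str) -> type:
    # Build a fresh class named "KernelDensityMatching" whose instances store a
    # single attribute in a __slots__ entry; different slot names give
    # incompatible instance layouts.
    return type("KernelDensityMatching", (), {"__slots__": (slot_name,)})


OldClass = make_class("length_scale")  # hypothetical slot names
NewClass = make_class("score_function")

instance = OldClass()
try:
    # CPython refuses to swap __class__ between classes with different layouts.
    instance.__class__ = NewClass
except TypeError as err:
    message = str(err)
else:
    message = ""

print(message)
# prints: __class__ assignment: 'KernelDensityMatching' object layout differs from 'KernelDensityMatching'
```

Giving both calls to `make_class` the same slot name makes the assignment succeed, which is a quick way to confirm that the failure concerns the instance layout rather than the class name.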