JaxGaussianProcesses / JaxLinOp

Linear Operators in JAX.
Apache License 2.0

bug: `static_field`s under `dataclass` inheritance are treated as dynamic. #25

Closed: daniel-dodd closed this issue 1 year ago

daniel-dodd commented 1 year ago

This behaviour is explained at https://github.com/cgarciae/simple-pytree/issues/8#issue-1630486498

Even though fields such as `shape` and `dtype` are marked as static on the `LinearOperator` base class, flattening a subclass such as `DenseLinearOperator` still returns them as leaves:

from jaxlinop import DenseLinearOperator
import jax.numpy as jnp
import jax.tree_util as jtu

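# shape and dtype are declared static on LinearOperator, so they should not appear among the leaves
jtu.tree_flatten(DenseLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]])))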
([Array([[1., 2.],
         [3., 4.]], dtype=float32),
  2,
  2,
  dtype('float32')],
 PyTreeDef(CustomNode(DenseLinearOperator[(('matrix', 'shape', 'dtype'), {'_pytree__initialized': True})], [*, (*, *), *])))
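For comparison, here is a minimal sketch of the flattening this issue expects, written with plain `jax.tree_util` rather than simple-pytree. `ToyDenseOperator` is a hypothetical stand-in for `DenseLinearOperator`: the matrix is returned as the only child, while `shape` and `dtype` travel in `aux_data`, which JAX stores in the `PyTreeDef` instead of the leaf list.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class ToyDenseOperator:
    """Hypothetical stand-in for DenseLinearOperator, registered by hand."""

    def __init__(self, matrix):
        self.matrix = matrix
        self.shape = tuple(matrix.shape)
        self.dtype = matrix.dtype

    def tree_flatten(self):
        # Dynamic data is returned as children (leaves); static data is
        # returned as aux_data, which JAX keeps inside the PyTreeDef.
        return (self.matrix,), (self.shape, self.dtype)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so nothing is recomputed from (possibly traced) children.
        obj = object.__new__(cls)
        (obj.matrix,) = children
        obj.shape, obj.dtype = aux_data
        return obj


leaves, treedef = jtu.tree_flatten(ToyDenseOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]])))
# leaves holds only the matrix; shape and dtype live in treedef.
print(leaves)
print(treedef)
```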

Hopefully this will be resolved upstream in simple-pytree; if not, we need a workaround. To close this issue, we also need a test that checks for this behaviour.
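If a workaround is needed, it could look like the following sketch. It assumes static fields are flagged through dataclass field metadata, and every name in it (`static_field`, `InheritedStaticPytree`, `LinearOperator`, `DenseOperator`) is hypothetical, not part of the simple-pytree or jaxlinop API. The key design choice is to look up static fields with `dataclasses.fields()` at flatten time, because that call also returns fields declared on dataclass base classes. The test function at the end is one possible shape for the regression test.

```python
import dataclasses
from typing import Any, Tuple

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def static_field(**kwargs: Any) -> Any:
    # Hypothetical helper: marks a dataclass field as static via its metadata.
    return dataclasses.field(metadata={"static": True}, **kwargs)


class InheritedStaticPytree:
    # Hypothetical base class (not simple-pytree or jaxlinop API).
    # Every subclass registers itself as a pytree node when it is defined.

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        jtu.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        # dataclasses.fields() includes fields declared on dataclass base
        # classes, so static markers on the base are honoured in subclasses.
        fields = dataclasses.fields(self)
        dynamic = tuple(f.name for f in fields if not f.metadata.get("static", False))
        static = tuple(
            (f.name, getattr(self, f.name)) for f in fields if f.metadata.get("static", False)
        )
        children = tuple(getattr(self, name) for name in dynamic)
        return children, (dynamic, static)

    @classmethod
    def _unflatten(cls, aux_data, children):
        dynamic, static = aux_data
        obj = object.__new__(cls)  # skip __init__/__post_init__ on rebuild
        for name, value in zip(dynamic, children):
            object.__setattr__(obj, name, value)
        for name, value in static:
            object.__setattr__(obj, name, value)
        return obj


@dataclasses.dataclass
class LinearOperator(InheritedStaticPytree):
    # Static fields are declared on the base class only.
    shape: Tuple[int, ...] = static_field(init=False)
    dtype: Any = static_field(init=False)


@dataclasses.dataclass
class DenseOperator(LinearOperator):
    matrix: jax.Array

    def __post_init__(self) -> None:
        self.shape = tuple(self.matrix.shape)
        self.dtype = self.matrix.dtype


def test_static_fields_are_not_leaves() -> None:
    # Regression check: only the matrix should be flattened into a leaf.
    op = DenseOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
    leaves, treedef = jtu.tree_flatten(op)
    assert len(leaves) == 1
    rebuilt = jtu.tree_unflatten(treedef, leaves)
    assert rebuilt.shape == (2, 2)
    assert rebuilt.dtype == op.dtype
```

Rebuilding through `object.__new__` means `__post_init__` is not rerun on traced or transformed children, so `shape` and `dtype` always come from the tree definition.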