Even though `shape` and `dtype` are marked as static on the `LinearOperator` base class, flattening a subclass such as `DenseLinearOperator` still returns these fields among the leaves, for example:
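```python
from jaxlinop import DenseLinearOperator
import jax.numpy as jnp
import jax.tree_util as jtu

# Flatten a small dense operator and inspect the leaves;
# `shape` and `dtype` should not appear among them.
leaves, treedef = jtu.tree_flatten(DenseLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]])))
print(leaves)
```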
Hopefully this will be resolved in `simple-pytree`; if not, we need a workaround. To close this issue, we also need a test that checks for this behaviour.
This behaviour is explained at https://github.com/cgarciae/simple-pytree/issues/8#issue-1630486498
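A minimal sketch of the expected behaviour and of such a test, written against plain `jax.tree_util` rather than jaxlinop or simple-pytree. `StaticDenseSketch` and `static_fields_are_not_leaves` are hypothetical names used only for illustration. The class keeps `shape` and `dtype` in the treedef's auxiliary data, which is what marking them static should achieve; an explicit `tree_flatten`/`tree_unflatten` pair like this one is also one possible shape for a workaround if `simple-pytree` does not change.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class StaticDenseSketch:
    """Hypothetical stand-in for DenseLinearOperator (not the jaxlinop class).

    Only `matrix` is a pytree leaf; `shape` and `dtype` are carried in the
    auxiliary data, i.e. they are treated as static.
    """

    def __init__(self, matrix, shape=None, dtype=None):
        self.matrix = matrix
        self.shape = matrix.shape if shape is None else shape
        self.dtype = matrix.dtype if dtype is None else dtype

    def tree_flatten(self):
        children = (self.matrix,)          # dynamic fields -> leaves
        aux_data = (self.shape, self.dtype)  # static fields -> treedef
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        shape, dtype = aux_data
        (matrix,) = children
        # Pass the static fields explicitly so `matrix` is never inspected here.
        return cls(matrix, shape=shape, dtype=dtype)


def static_fields_are_not_leaves(op, expected_num_leaves):
    """Regression check: flattening should return only the array leaves."""
    leaves, _ = jtu.tree_flatten(op)
    return len(leaves) == expected_num_leaves and all(
        isinstance(leaf, jax.Array) for leaf in leaves
    )
```

In the jaxlinop test suite, the same check could be applied to `DenseLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))` with one expected leaf, assuming the dense matrix is its only non-static field.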