adam-hartshorne opened this issue 1 year ago.
Hey there. This looks like a bug in CoLA. The issue is that they're using __init__ inside the unflatten rule of their pytrees: https://github.com/wilson-labs/cola/blob/50d61041b94b0f092e89cc05169538b6e9caeeba/cola/ops/operator_base.py#L250
This is a common mistake when implementing custom pytrees in JAX; see the documentation here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
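To make the pitfall concrete, here is a minimal pure-JAX sketch. BrokenOperator is a hypothetical stand-in, not CoLA's actual LinearOperator. Its unflatten rule calls __init__, and __init__ does work on its argument. That is fine for real arrays, but JAX and Equinox are allowed to unflatten with placeholder leaves. Mapping every leaf to None with jax.tree_util.tree_map roughly mimics what Equinox filter functions such as eqx.partition do, and it triggers the failure:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class BrokenOperator:
    """Hypothetical operator whose unflatten rule goes through __init__."""

    def __init__(self, matrix):
        # __init__ inspects its argument. This only works for real arrays.
        self.matrix = matrix
        self.shape = matrix.shape

    def tree_flatten(self):
        # One dynamic leaf (the matrix) and no static aux data.
        return (self.matrix,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # The mistake: rebuilding the object by calling __init__.
        return cls(*children)


op = BrokenOperator(jnp.eye(3))

# Rebuild the pytree with every leaf replaced by None, similar to what
# filter/partition utilities do for the leaves they filter out.
try:
    jax.tree_util.tree_map(lambda _: None, op)
    unflatten_failed = False
except AttributeError as err:
    # __init__ runs with matrix=None, so `None.shape` raises.
    unflatten_failed = True
    print("unflatten failed:", err)
```

Constructing the operator directly works. Only the unflatten path breaks, which is why the error shows up inside library code rather than at construction time.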
I'd recommend reporting the bug to CoLA.
They could fix this by either (a) having their LinearOperator inherit from eqx.Module, or (b) unflattening via __new__ and __setattr__; see the Equinox implementation for reference.
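A minimal sketch of option (b), again in pure JAX with a hypothetical FixedOperator class rather than CoLA's real code. The unflatten rule allocates the object with cls.__new__(cls) and sets attributes with object.__setattr__, so __init__ (and any validation in it) never runs on placeholder leaves. The static shape is carried as aux data, which is an assumption of this sketch. Using object.__setattr__ also works on frozen classes, which is the situation in Equinox, where modules are immutable. Option (a), subclassing eqx.Module, gets this behaviour from Equinox automatically.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class FixedOperator:
    """Hypothetical operator whose unflatten rule bypasses __init__."""

    def __init__(self, matrix):
        # Safe to inspect the argument here: unflatten never calls __init__.
        self.matrix = matrix
        self.shape = matrix.shape

    def tree_flatten(self):
        # Dynamic leaf: the matrix. Static aux data: the (hashable) shape tuple.
        return (self.matrix,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Allocate without running __init__, then set attributes directly.
        obj = cls.__new__(cls)
        object.__setattr__(obj, "matrix", children[0])
        object.__setattr__(obj, "shape", aux_data)
        return obj


op = FixedOperator(jnp.eye(3))

# Placeholder leaves no longer break unflattening.
placeholder = jax.tree_util.tree_map(lambda _: None, op)

# Ordinary tree maps and jit still work as expected.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, op)


@jax.jit
def apply(operator, v):
    # Inside jit, operator.matrix is a tracer; operator.shape stays static.
    return operator.matrix @ v


result = apply(op, jnp.ones(3))
```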
Incidentally, it's interesting to see another JAX-compatible library for linear operators. We also have Lineax, which is part of the JAX+Equinox ecosystem and provides both linear operators and linear solves. It looks like the two libraries have substantially overlapping (but not identical) functionality. For example, Lineax supports pseudoinverses, but CoLA supports eigendecompositions.
Thanks for the fast response. I will report it on their GitHub.
The CoLA linear operator library (https://cola.readthedocs.io/en/latest/index.html) was recently released. It allows lazy evaluation as follows:
lazy_A = cola.fns.lazify(A)
Here lazy_A is a cola.ops.LinearOperator. Overall the library is incredibly powerful, and I can see it becoming a core tool within the JAX ecosystem.
I was wondering whether it is possible to store such an operator as a field in an Equinox Module, since the following MVE produces an error when calling Equinox's filter functionality:
The problem is that the "type" of the lazy variable is <10x10 Dense with dtype=float32>.
Is there an easy way around this, so that Equinox is compatible with CoLA lazified variables?