patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Storing CoLA Linear Operators as a Field in a Module #453

Open adam-hartshorne opened 1 year ago

adam-hartshorne commented 1 year ago

The linear operator library CoLA (https://cola.readthedocs.io/en/latest/index.html) was recently released. It supports lazy evaluation like this:

```python
lazy_A = cola.fns.lazify(A)
```

where lazy_A is a cola.ops.LinearOperator. Overall the library is incredibly powerful, and I expect it to become a core tool in the JAX ecosystem.

I was wondering whether it is possible to store such an operator as a field of an Equinox Module, because the following MVE raises an error when calling eqx.filter:

```python
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import cola

class MyModule(eqx.Module):

    lazy_A: cola.ops.LinearOperator

    def __init__(self, A):
        # Wrap the dense array in a lazy CoLA operator.
        self.lazy_A = cola.fns.lazify(A)

    def __call__(self, x):
        return self.lazy_A @ x

seed = jr.PRNGKey(0)
A = jr.normal(seed, (10, 10))
X = jnp.ones((10, 1))
model = MyModule(A)
# This call raises the AttributeError shown in the traceback below.
result = eqx.filter(model, eqx.is_inexact_array)
```

```
Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/test_equinox_lazy_variable/test_equinox_lazy_variable.py", line 24, in <module>
    result = eqx.filter(model, eqx.is_inexact_array)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter
    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 71, in _filter_tree
    return jtu.tree_map(mask, arg, is_leaf=is_leaf)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 245, in tree_unflatten
    return cls(*new_args, **aux[0])
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 21, in __init__
    super().__init__(dtype=A.dtype, shape=A.shape)
AttributeError: 'bool' object has no attribute 'dtype'
```

The problem seems to be related to the type of the lazy variable, which is displayed as

<10x10 Dense with dtype=float32>

Is there an easy way around this, so that Equinox can be used with CoLA lazified variables?

patrick-kidger commented 1 year ago

Hey there. This looks like a bug in CoLA. The issue is that the unflatten rule of their pytrees calls __init__ (the `return cls(*new_args, **aux[0])` line in your traceback): https://github.com/wilson-labs/cola/blob/50d61041b94b0f092e89cc05169538b6e9caeeba/cola/ops/operator_base.py#L250

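This is a common mistake when implementing custom pytrees in JAX: a pytree may be unflattened with leaves that are not arrays. Here, eqx.filter maps every leaf to a bool and rebuilds the tree, so CoLA's __init__ receives a bool instead of an array and fails on A.dtype. See the documentation here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization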
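
A minimal pure-JAX sketch of the pitfall and of fix (b) follows. BrokenDense and FixedDense are hypothetical toy classes, not CoLA's actual code. The `tree_map(lambda _: True, ...)` call is an assumed stand-in for what eqx.filter does internally according to the traceback: it maps each leaf to a bool and unflattens the result.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class BrokenDense:
    """Hypothetical operator whose unflatten rule re-runs __init__ (the pitfall)."""

    def __init__(self, A):
        self.A = A
        # __init__ inspects its argument, which fails if A is not an array.
        self.dtype = A.dtype
        self.shape = A.shape

    def tree_flatten(self):
        return (self.A,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)  # calls __init__ on whatever leaves JAX passes in


@jtu.register_pytree_node_class
class FixedDense:
    """Hypothetical operator whose unflatten rule bypasses __init__."""

    def __init__(self, A):
        self.A = A
        self.dtype = A.dtype
        self.shape = A.shape

    def tree_flatten(self):
        # dtype and shape are static metadata, so keep them as aux data.
        return (self.A,), (self.dtype, self.shape)

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = object.__new__(cls)  # create the instance without running __init__
        (obj.A,) = children
        obj.dtype, obj.shape = aux
        return obj


A = jnp.ones((10, 10))

# Mimic eqx.filter: replace every leaf with a bool, then rebuild the pytree.
try:
    jtu.tree_map(lambda _: True, BrokenDense(A))
    broken_ok = True
except AttributeError as e:
    broken_ok = False
    print("BrokenDense:", e)  # 'bool' object has no attribute 'dtype'

mask = jtu.tree_map(lambda _: True, FixedDense(A))
print("FixedDense leaf:", mask.A)  # True
```

The only difference between the two classes is tree_unflatten: FixedDense never calls __init__, so it tolerates the non-array leaves that JAX and Equinox may pass through it.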

I'd recommend reporting the bug to CoLA.

They could fix this either by (a) having their LinearOperator inherit from eqx.Module, or (b) unflattening via __new__ and __setattr__; see the Equinox implementation for reference.


Incidentally, it's interesting to see another JAX-compatible library for linear operators. We also have Lineax, which is part of the JAX+Equinox ecosystem and provides both linear operators and linear solves. The two libraries appear to have substantially overlapping (but not identical) functionality. For example, Lineax supports pseudoinverses, whereas CoLA supports eigendecompositions.

adam-hartshorne commented 1 year ago

Thanks for the fast response. I will report it on their GitHub.