patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.13k stars 143 forks

Filter Primitive Transpose for trace of linear operators #847

Closed quattro closed 2 months ago

quattro commented 2 months ago

Hi 👋,

I'm not sure whether this or lineax is the best place for my question, but I've been using the equinox internal API to define a trace primitive for our traceax package, which builds upon the linear operators defined in lineax.

Using the API works like a charm for most of the required definitions (e.g., abstract_eval, jvp, etc.); however, I am having a bit of difficulty getting the transpose definition to work with general linear operators, whose internal shape/structure may change.

Setting aside the estimators for the trace function and focusing on its definition, the Jacobian of trace(A) should be the identity matrix (or ravel(I) if we're being pedantic). Thus the JVP is the inner product <d tr(A) / d A, V> = <I, V> = tr(V). Okay, JVP works like a charm.
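As a quick sanity check of that claim, here is a minimal sketch using plain `jnp.trace` on a dense array (not the traceax primitive or a lineax operator; the matrices `A` and `V` are made-up examples):

```python
import jax
import jax.numpy as jnp

A = jnp.arange(9.0).reshape(3, 3)  # example primal matrix (assumption)
V = jnp.ones((3, 3))               # example tangent direction (assumption)

# The gradient of trace with respect to A should be the identity matrix.
jac = jax.grad(jnp.trace)(A)

# JVP: pushing the tangent V through trace should give <I, V> = tr(V).
primal_out, tangent_out = jax.jvp(jnp.trace, (A,), (V,))
```

Here `jac` should match `jnp.eye(3)` and `tangent_out` should match `jnp.trace(V)`.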

The VJP, by its definition using the transposed Jacobian, should be J'v, which here results in I * ct, where ct is the cotangent of the primal output.
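The same can be checked for the VJP. Again this is only a sketch on a dense array with `jnp.trace`, with an arbitrary example cotangent value of 2.0, and not the custom primitive itself:

```python
import jax
import jax.numpy as jnp

A = jnp.arange(9.0).reshape(3, 3)  # example primal matrix (assumption)

# jax.vjp returns the primal output and a function mapping cotangents back.
out, vjp_fn = jax.vjp(jnp.trace, A)

# Example cotangent with the same shape/dtype as the scalar output.
ct = 2.0 * jnp.ones_like(out)
(ct_A,) = vjp_fn(ct)  # expected to equal ct * I
```

The pulled-back cotangent `ct_A` is a dense 3x3 array, i.e. it has the same shape as the primal input `A`.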

The issue I'm running into is that defining the transpose rule as,

@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
    key, operator, state, k, estimator = inputs
    cts_result, cts_stats = cts_out
    in_struct = operator.in_structure()    # cannot do! jacobian is linear wrt operator and internals are UndefinedPrimals!
    out_struct = operator.out_structure()  # cannot do! jacobian is linear wrt operator and internals are UndefinedPrimals!
    op_t = cts_result * lx.IdentityLinearOperator(in_struct, out_struct)
    key_none = jtu.tree_map(lambda _: None, key)
    state_none = jtu.tree_map(lambda _: None, state)
    k_none = None
    estimator_none = jtu.tree_map(lambda _: None, estimator)

    return key_none, op_t, state_none, k_none, estimator_none

with eqxi.Flatten.__call__ complaining,

    assert jtu.tree_structure(out, is_leaf=_is_none) == jtu.tree_structure(
        like, is_leaf=_is_none
    )

throwing an assertion error if I depend on the specific operator structure. For example, given a MatrixLinearOperator, one could pull the shape from the abstract value (operator.matrix.aval.shape):

@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
    key, operator, state, _, estimator = inputs
    cts_result, _ = cts_out
    # HACK ALERT
    n = operator.matrix.aval.shape[0]
    dtype = operator.matrix.aval.dtype
    struct = jax.ShapeDtypeStruct(shape=(n,), dtype=dtype)
    op_t = cts_result * IdentityLinearOperator(struct, struct)
    key_none = jtu.tree_map(lambda _: None, key)
    state_none = jtu.tree_map(lambda _: None, state)
    k_none = None
    estimator_none = jtu.tree_map(lambda _: None, estimator)

    return key_none, op_t, state_none, k_none, estimator_none

I am fairly sure the output structure from _estimate_trace_transpose is the same PyTree structure as the primitive's inputs (with Nones defined appropriately), but I cannot quite work out which shape I should match.

patrick-kidger commented 2 months ago

Oh, interesting! I'm definitely curious how you find using this API. Right now the eqxi.filter_primitive_* interface exists only for lineax.linear_solve_p -- and it's definitely an advanced API surface -- so I think you're breaking new ground here.

As for the problem you're facing, my first guess is that something like the following may be appropriate:

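from jax.interpreters.ad import UndefinedPrimal

def _remove_undefined_primal(x):
    # Replace an UndefinedPrimal placeholder with its abstract value (shape/dtype);
    # leave every other leaf untouched.
    if type(x) is UndefinedPrimal:
        return x.aval
    else:
        return x

def _estimate_trace_transpose(inputs, cts_out):
    _, operator, _, _, _ = inputs
    operator_struct = jtu.tree_map(_remove_undefined_primal, operator)
    in_structure = eqxi.filter_eval_shape(lambda o: o.in_structure(), operator_struct)
    out_structure = eqxi.filter_eval_shape(lambda o: o.out_structure(), operator_struct)
    ...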

WDYT?

quattro commented 2 months ago

Thanks so much for the speedy reply, as always!

Your code (with minor modifications to match imports and leaf checking) pulls the structure no problem!

@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
    # the Jacobian of the trace is just the identity matrix, i.e. J = I
    # so J'v = I v = v

    # primal inputs; operator should have UndefinedPrimal leaves
    key, operator, state, _, estimator = inputs

    # co-tangent of the trace approximation and the stats (None)
    cts_result, _ = cts_out

    # the internals of the operator are UndefinedPrimal leaves so
    # we need to rely on abstract values to pull structure info
    operator_struct = jtu.tree_map(_remove_undefined_primal, operator, is_leaf=_is_undefined)
    in_structure = eqx.filter_eval_shape(lambda o: o.in_structure(), operator_struct)
    out_structure = eqx.filter_eval_shape(lambda o: o.out_structure(), operator_struct)
    op_t = cts_result * IdentityLinearOperator(in_structure, out_structure)

    key_none = jtu.tree_map(lambda _: None, key)
    state_none = jtu.tree_map(lambda _: None, state)
    k_none = None
    estimator_none = jtu.tree_map(lambda _: None, estimator)

    return key_none, op_t, state_none, k_none, estimator_none
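The `_remove_undefined_primal` and `_is_undefined` helpers are not shown in the snippet above. A minimal sketch of what they might look like follows, exercised on a toy pytree rather than a real lineax operator. The dict `toy_operator` and the 50x50 shape are stand-ins chosen purely for illustration, and `_is_undefined` is my guess at the leaf check:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.interpreters.ad import UndefinedPrimal


def _is_undefined(x):
    # Treat UndefinedPrimal placeholders as leaves during tree_map.
    return isinstance(x, UndefinedPrimal)


def _remove_undefined_primal(x):
    # Swap the placeholder for its abstract value; keep other leaves as they are.
    if _is_undefined(x):
        return x.aval
    return x


# Toy stand-in for an operator whose array leaf is an UndefinedPrimal (assumption).
aval = jax.core.ShapedArray((50, 50), jnp.dtype("float32"))
toy_operator = {"matrix": UndefinedPrimal(aval), "tags": frozenset()}

struct_tree = jtu.tree_map(
    _remove_undefined_primal, toy_operator, is_leaf=_is_undefined
)

# Shape and dtype are now readable without touching any concrete values.
n = struct_tree["matrix"].shape[0]
in_struct = jax.ShapeDtypeStruct((n,), struct_tree["matrix"].dtype)
```

With the placeholders replaced by abstract values, structure queries only ever see shapes and dtypes, which is all that the identity operator's input/output structure needs.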

With this wonderful fix in place, a downstream issue appears with the Flatten functionality in equinox internals: an assertion error is thrown at line 116 due to differing PyTree structures.

For instance, if the primal operator is a MatrixLinearOperator, the returned cotangent operator is a MulLinearOperator, because the cotangent of the trace scales the IdentityLinearOperator.

Revealing the relevant part of the tree structures we see,

CustomNode(MatrixLinearOperator[('matrix',), ('tags',), (frozenset(),)], [*])

vs

CustomNode(MulLinearOperator[('operator', 'scalar'), (), ()], [CustomNode(IdentityLinearOperator[(), ('input_structure', 'output_structure'), (([ShapeDtypeStruct(shape=(50,), dtype=float32)], PyTreeDef(*)), ([ShapeDtypeStruct(shape=(50,), dtype=float32)], PyTreeDef(*)))], []), *])

I'm not sure I see a clean workaround, given that this check is necessary for more general cases and the linear algebra trace function is just a simplifying special case.

This has been massively helpful in understanding both JVP/VJP and the internals of equinox.

patrick-kidger commented 2 months ago

Ah, that might take a bit more work :)

Matching structures here is a JAX constraint. Moreover, they're not treated as linear operators at all; all that JAX cares about is that the primals and tangents match as pytrees-of-arrays.
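To illustrate the constraint with stock JAX only, here is a small sketch in which a plain dict stands in for an operator pytree (the `params` dict and `trace_of` function are hypothetical, not lineax or traceax code). The cotangent that comes back from `jax.vjp` has exactly the primal's tree structure, with a dense array where the primal had one:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def trace_of(params):
    # Hypothetical stand-in for "trace of an operator holding a dense matrix".
    return jnp.trace(params["matrix"])


params = {"matrix": jnp.arange(9.0).reshape(3, 3)}

out, vjp_fn = jax.vjp(trace_of, params)
(ct_params,) = vjp_fn(jnp.ones_like(out))

# The cotangent pytree mirrors the primal pytree leaf-for-leaf.
same_structure = jtu.tree_structure(ct_params) == jtu.tree_structure(params)
```

In this dense toy case the identity shows up as a materialised 3x3 array in `ct_params["matrix"]`, not as a structured identity object, which is consistent with the point that JAX only sees pytrees-of-arrays.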

quattro commented 2 months ago

Ooof, okay. This is certainly disappointing. I appreciate your help, and thanks again for the incredibly useful libraries. Cheers.