Closed: quattro closed this issue 2 months ago
Oh, interesting! I'm definitely curious how you find using this API. Right now the eqxi.filter_primitive_* interface exists only for lineax.linear_solve_p (and it's definitely an advanced API surface), so I think you're breaking new ground here.
As for the problem you're facing, my first guess is that something like the following may be appropriate:
def _remove_undefined_primal(x):
    if type(x) is UndefinedPrimal:
        return x.aval
    else:
        return x  # leave every other leaf unchanged

def _estimate_trace_transpose(...):
    operator_struct = jtu.tree_map(_remove_undefined_primal, operator)
    in_structure = eqxi.filter_eval_shape(lambda o: o.in_structure(), operator_struct)
    out_structure = eqxi.filter_eval_shape(lambda o: o.out_structure(), operator_struct)
    ...
WDYT?
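To make the idea concrete, here is a minimal sketch that only needs JAX. The dict `operator` is a hypothetical stand-in for the leaves of a lineax operator, the `_is_undefined` helper is an assumed leaf check, and `jax.eval_shape` is used in place of the equinox helper, so treat it as an illustration rather than the actual traceax code.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.core import ShapedArray
from jax.interpreters import ad


def _is_undefined(x):
    # Assumed leaf check: stop at UndefinedPrimal placeholders.
    return type(x) is ad.UndefinedPrimal


def _remove_undefined_primal(x):
    # Swap the placeholder for its abstract value; keep everything else.
    return x.aval if _is_undefined(x) else x


# Hypothetical stand-in for an operator as seen inside a transpose rule.
operator = {"matrix": ad.UndefinedPrimal(ShapedArray((50, 50), jnp.float32))}

operator_struct = jtu.tree_map(
    _remove_undefined_primal, operator, is_leaf=_is_undefined
)
aval = operator_struct["matrix"]
matrix_struct = jax.ShapeDtypeStruct(aval.shape, aval.dtype)

# Shape-only evaluation: the structure of a matrix-vector product.
out_structure = jax.eval_shape(
    lambda m: m @ jnp.ones(m.shape[1], m.dtype), matrix_struct
)
```

No array is ever materialised here; only shapes and dtypes flow through, which is all the transpose rule has available for the undefined inputs.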
Thanks so much for the speedy reply, as always!
Your code (with minor modifications to match imports and leaf checking) pulls the structure no problem!
@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
    # the Jacobian of the trace is just the identity matrix, i.e. J = I
    # so J'v = I v = v

    # primal inputs; operator should have UndefinedPrimal leaves
    key, operator, state, _, estimator = inputs

    # co-tangent of the trace approximation and the stats (None)
    cts_result, _ = cts_out

    # the internals of the operator are UndefinedPrimal leaves so
    # we need to rely on abstract values to pull structure info
    operator_struct = jtu.tree_map(_remove_undefined_primal, operator, is_leaf=_is_undefined)
    in_structure = eqx.filter_eval_shape(lambda o: o.in_structure(), operator_struct)
    out_structure = eqx.filter_eval_shape(lambda o: o.out_structure(), operator_struct)

    op_t = cts_result * IdentityLinearOperator(in_structure, out_structure)

    # no cotangents for the non-differentiable inputs
    key_none = jtu.tree_map(lambda _: None, key)
    state_none = jtu.tree_map(lambda _: None, state)
    k_none = None
    estimator_none = jtu.tree_map(lambda _: None, estimator)

    return key_none, op_t, state_none, k_none, estimator_none
Given this wonderful fix, there is a downstream issue with the Flatten functionality in equinox internals. An assertion error is thrown at line 116 due to different PyTree structures. For instance, if the primal operator was a MatrixLinearOperator, the returned cotangent operator is a MulLinearOperator, because the cotangent of the trace scales the IdentityLinearOperator.
Revealing the relevant part of the tree structures we see,
CustomNode(MatrixLinearOperator[('matrix',), ('tags',), (frozenset(),)], [*])
vs
CustomNode(MulLinearOperator[('operator', 'scalar'), (), ()], [CustomNode(IdentityLinearOperator[(), ('input_structure', 'output_structure'), (([ShapeDtypeStruct(shape=(50,), dtype=float32)], PyTreeDef(*)), ([ShapeDtypeStruct(shape=(50,), dtype=float32)], PyTreeDef(*)))], []), *])
I'm not sure I see a clean workaround given the necessity of this check for more general cases, and the simplifying case of the linear algebra trace function.
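For anyone following along, the mismatch can be reproduced without lineax. This is only a sketch: `MatrixOp`, `MulOp` and `IdentityOp` are hypothetical NamedTuple stand-ins for the real operators, chosen because NamedTuples are pytrees out of the box.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp
import jax.tree_util as jtu


# Hypothetical stand-ins for the lineax operator classes.
class MatrixOp(NamedTuple):
    matrix: Any


class IdentityOp(NamedTuple):
    pass


class MulOp(NamedTuple):
    operator: Any
    scalar: Any


primal = MatrixOp(matrix=jnp.zeros((50, 50)))
cotangent = MulOp(operator=IdentityOp(), scalar=jnp.asarray(2.0))

primal_def = jtu.tree_structure(primal)
cotangent_def = jtu.tree_structure(cotangent)

# Both trees hold exactly one array leaf, yet the structures differ
# because the container types differ.
same_structure = primal_def == cotangent_def
```

Printing `primal_def` and `cotangent_def` gives a compact view of where the two trees diverge, similar to the output above.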
This has been massively helpful in understanding both JVP/VJP and the internals of equinox.
Ah, that might take a bit more work :)
Matching structures here is a JAX constraint. Moreover, they're not treated as linear operators at all: all that JAX cares about is that the primals and tangents match as pytrees of arrays.
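A small illustration of that constraint, using plain `jax.vjp` rather than the primitive machinery: the cotangent JAX hands back has the same pytree structure as the primal input, leaf for leaf. The dict `op` is a hypothetical stand-in for a matrix-backed operator, not a lineax object.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def trace_of(op):
    # `op` is a hypothetical dict pytree standing in for an operator.
    return jnp.trace(op["matrix"])


op = {"matrix": jnp.arange(9.0).reshape(3, 3)}

value, vjp_fn = jax.vjp(trace_of, op)
(op_ct,) = vjp_fn(jnp.asarray(2.0))

# The cotangent mirrors the primal: same container, same leaf shapes.
structures_match = jtu.tree_structure(op_ct) == jtu.tree_structure(op)
```

Under this reading, a transpose rule for a matrix-backed operator would presumably need to return something shaped like the primal (here a dense `2 * I` in the `matrix` slot) rather than a scaled identity operator.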
Ooof, okay. This is certainly disappointing. I appreciate your help, and thanks again for the incredibly useful libraries. Cheers.
Hi,
I'm not sure if this is the best place for my question, or whether lineax would be better, but I've been using the equinox internal API to define a trace primitive for our traceax package, which builds upon the linear operators defined in lineax.
Using the API works like a charm for most of the required definitions (e.g., abstract_eval, jvp, etc.); however, I am having a bit of difficulty getting the transpose definition to work with general linear operators, whose internal shape/structure may change.
Without worrying about the estimators for the trace function and focusing on its definition, the Jacobian for trace(A) should result in the identity matrix (or ravel(I) if we're being pedantic). Thus the JVP is the inner product between <d tr(A) / d A, V> = <I, V> = tr(V). Okay, JVP works like a charm.
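The JVP identity above can be checked numerically on a dense matrix. This is a quick sanity check with plain `jnp.trace`, not the traceax primitive; `A` and `V` are arbitrary example values.

```python
import jax
import jax.numpy as jnp

A = jnp.arange(16.0).reshape(4, 4)  # primal point
V = jnp.ones((4, 4)) + jnp.eye(4)   # tangent direction

# Pushing V forward through trace should give <I, V> = tr(V).
primal_out, tangent_out = jax.jvp(jnp.trace, (A,), (V,))
expected_tangent = jnp.trace(V)
```

The tangent does not depend on `A` at all, which reflects the trace being linear in its argument.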
VJP, by its definition using the Jacobian and transpose should be J'v, which here results in I * ct, where ct is the cotangent of the primal output.
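The same kind of check works for the VJP. Again this uses plain `jnp.trace` on a dense matrix as a stand-in for the primitive, with an arbitrary cotangent value `ct`.

```python
import jax
import jax.numpy as jnp

A = jnp.arange(16.0).reshape(4, 4)
ct = jnp.asarray(3.0)  # cotangent of the scalar output

# Pulling ct back through trace should give ct * I.
_, vjp_fn = jax.vjp(jnp.trace, A)
(A_bar,) = vjp_fn(ct)

# The gradient is the special case ct = 1.
grad_A = jax.grad(jnp.trace)(A)
```

Note that `A_bar` comes back as a dense array with the shape of `A`, which is the array-level view JAX takes of the cotangent.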
The issue I'm running into is that defining the transpose rule results in eqxi.Flatten.__call__ complaining, throwing an assertion error, if I depend on the specific operator structure. For example, given MatrixLinearOperator one could pull the shape from the abstract value, for instance operator.matrix.aval.shape. I am fairly sure the output structure from _estimate_trace_transpose is the same PyTree structure as the primitive (with Nones defined appropriately), but I cannot quite understand what shape I should match to.