patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

custom trace estimation resulting in 2 * expected result #875

Closed nahid18 closed 1 month ago

nahid18 commented 1 month ago

Hi 👋 ,

I'm not sure if this is the best place for my question, but I've been working on this project (#847) with @quattro, where we use the Equinox internal API (eqxi) to define a trace primitive for our traceax package, which builds on lineax's linear operators.

The issue I'm running into is that our custom trace estimation implementation produces results that are 2 × the expected value, compared with substituting jax.numpy.trace for the tangent trace estimate.

The issue gets fixed if I replace (in the jvp implementation below) this

```python
t_state = estimator.init(t_key, t_operator)
t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator)
```

with this

```python
t_result = jax.numpy.trace(t_operator.as_matrix())
```
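For context, the identity the tangent rule is meant to satisfy can be checked directly against plain JAX, without any of the traceax internals (a standalone sanity check on dense matrices):

```python
import jax
import jax.numpy as jnp

A = jnp.arange(9.0).reshape(3, 3)  # arbitrary primal matrix
V = jnp.ones((3, 3))               # arbitrary tangent direction

# Since d tr(A)/dA = I, the JVP of trace in direction V is <I, V> = tr(V).
primal_out, tangent_out = jax.jvp(jnp.trace, (A,), (V,))
assert jnp.allclose(primal_out, jnp.trace(A))   # tr(A) = 12.0
assert jnp.allclose(tangent_out, jnp.trace(V))  # tr(V) = 3.0
```

So whatever the tangent-side primitive bind computes, its value should agree with a single tr(V), applied once.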

jvp implementation

```python
@eqxi.filter_primitive_jvp
def _estimate_trace_jvp(primals, tangents):
    key, operator, state, k, estimator = primals
    # t_operator := V
    t_key, t_operator, t_state, t_k, t_estimator = tangents
    jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator))
    del t_key, t_state, t_k, t_estimator

    # primal problem: t = tr(A)
    result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator)
    out = result, stats

    # inner product in linear operator space => <A, B> = tr(A @ B)
    # d tr(A) / dA = I
    # t' = <tr'(A), V> = tr(I @ V) = tr(V)
    # tangent problem => tr(V)
    key, t_key = rdm.split(key)
    if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)):
        t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none)
        t_operator = lx.TangentLinearOperator(operator, t_operator)

    t_state = estimator.init(t_key, t_operator)
    t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator)
    t_out = (
        t_result,
        jtu.tree_map(lambda _: None, stats),
    )

    return out, t_out
```
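As a point of comparison, here is a minimal jax.custom_jvp sketch of the same rule with all the estimator machinery stripped out. This is an illustrative standalone example, not the traceax code: when the tangent tr(V) is emitted exactly once, reverse-mode differentiation recovers the identity gradient.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_trace(a):
    return jnp.trace(a)

@my_trace.defjvp
def my_trace_jvp(primals, tangents):
    (a,), (v,) = primals, tangents
    # tangent of tr(A) in direction V is tr(V), contributed exactly once
    return my_trace(a), jnp.trace(v)

A = jnp.eye(3)
g = jax.grad(my_trace)(A)
assert jnp.allclose(g, jnp.eye(3))  # identity, not 2 * identity
```

If the gradient came out as 2 * I here, it would mean the tangent (or its transpose) was being accounted for twice somewhere in the rule.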

transpose implementation

```python
@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
    # the Jacobian of the trace is just the identity matrix, i.e. J = I,
    # so J^T v = I v = v

    # primal inputs; operator should have UndefinedPrimal leaves
    key, operator, state, _, estimator = inputs

    # co-tangent of the trace approximation and the stats (None)
    cts_result, _ = cts_out

    # the internals of the operator are UndefinedPrimal leaves, so
    # we need to rely on abstract values to pull structure info
    op_t = _make_identity(operator, cts_result)

    key_none = jtu.tree_map(lambda _: None, key)
    state_none = (None, op_t, None)
    k_none = None
    estimator_none = jtu.tree_map(lambda _: None, estimator)

    return key_none, op_t, state_none, k_none, estimator_none
```
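The transpose reasoning above can likewise be checked against plain JAX: transposing the linear map V ↦ tr(V) sends a scalar cotangent c back to c * I (a standalone sketch on dense matrices, independent of the lineax operator types):

```python
import jax
import jax.numpy as jnp

A = jnp.eye(4)
# VJP of trace: the pullback of a scalar cotangent c through tr is c * I,
# matching the identity-Jacobian reasoning in the transpose rule.
_, vjp_fn = jax.vjp(jnp.trace, A)
(ct,) = vjp_fn(2.0)
assert jnp.allclose(ct, 2.0 * jnp.eye(4))
```

So a correct transpose rule should scale the identity by the incoming cotangent exactly once.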

minimal test case that demonstrates the issue:

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx
import traceax as tx

SEED = 0
N = 50
K = 5
key = rdm.PRNGKey(SEED)

def _tr(op, k):
    result = tx.trace(key, op, k)
    return result.value

op = lx.MatrixLinearOperator(jnp.eye(N))
g = jax.grad(_tr)(op, K)
print(f"Grad of mat {op} is {g.as_matrix()}")
```

The resulting gradient should be the identity matrix, but our code produces a matrix with 2s on the diagonal (i.e. 2 * I).
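For reference, the expected answer can be reproduced with jnp.trace directly, with no traceax involved:

```python
import jax
import jax.numpy as jnp

N = 50
# Reference gradient: d tr(A)/dA = I, so grad of trace at any A is the
# identity. A correct estimator's gradient should match this, not 2 * I.
g_ref = jax.grad(jnp.trace)(jnp.eye(N))
assert jnp.allclose(g_ref, jnp.eye(N))
```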

I am not quite sure what I am missing here. Your input would be appreciated. Thank you so much!

patrick-kidger commented 1 month ago

Hey there! I've been enjoying watching your progress with Traceax.

That said, I'm afraid this probably isn't an issue I can help you debug. I'm already hard-pressed to find enough time for my own projects, let alone to work through thorny details in someone else's! :)