Closed: nahid18 closed this issue 1 month ago.
Hey there! I've been enjoying watching your progress with Traceax.
That said, I'm afraid this probably isn't an issue I can help you debug. I already struggle to find enough time for my own projects, let alone to work through thorny details in someone else's! :)
Hi 👋,
I'm not sure if this is the best place for my question, but I've been working on this project (#847) with @quattro, where we use the `equinox` internal API (`eqxi`) to define a trace primitive for our `traceax` package, which builds on the linear operators of `lineax`.
The issue I'm running into is that, for tangent trace estimation, our custom trace implementation produces results that are twice the expected value when compared with `jax.numpy.trace`. The problem goes away if I replace this (in the JVP implementation below):
with this:
JVP implementation:
Transpose implementation:
A minimal test case that demonstrates the issue:
The resulting gradient should be the identity matrix, but our code produces a matrix whose diagonal elements are all 2 instead of 1, i.e. twice the expected gradient.
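A reference point for checking a custom transpose rule: for the linear map dA ↦ tr(dA), the transpose sends a scalar cotangent c to c·I. The sketch below (the helper name `reference_trace_cotangents` is hypothetical) asks JAX to transpose `jnp.trace` directly with `jax.linear_transpose` and compares the result with `jax.grad`. Under these assumptions both should return the identity, which is what a correct custom transpose rule would be expected to reproduce for a unit cotangent.

```python
import jax
import jax.numpy as jnp


def reference_trace_cotangents(n: int):
    # The primal value only fixes shape and dtype for linear_transpose.
    A = jnp.zeros((n, n), dtype=jnp.float32)
    # jnp.trace is linear in A, so JAX can transpose it directly.
    transpose_trace = jax.linear_transpose(jnp.trace, A)
    # A unit cotangent on the scalar output pulls back to the identity matrix.
    (ct_A,) = transpose_trace(jnp.ones((), dtype=A.dtype))
    # Reverse mode through jnp.trace should agree with the explicit transpose.
    grad_A = jax.grad(jnp.trace)(A)
    return ct_A, grad_A


ct_A, grad_A = reference_trace_cotangents(4)
```

If a custom transpose rule returns `2 * ct * I` for a cotangent `ct`, or if both the JVP and transpose paths add a contribution for the same input, the result would no longer match this reference.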
I am not quite sure what I am missing here. Your input would be appreciated. Thank you so much!