Note: the current version of this package is not intended for general consumption.
DiffLinearAlgebra can be (very loosely) thought of as DiffRules.jl for linear algebra. For every sensitivity, we provide a function which, given the inputs and output of the forward pass and the reverse-mode sensitivity w.r.t. that output, computes the sensitivity w.r.t. the specified argument.
A, B = randn(5, 3), randn(3, 4)
C, C̄ = A * B, randn(5, 4)
Ā = ∇(*, Val{1}, (), C, C̄, A, B)
B̄ = ∇(*, Val{2}, (), C, C̄, A, B)
In the above example, the sensitivities of A and B are computed from C and a random seeding of C̄. (Note that the third argument is currently redundant; see this issue for motivation for its inclusion.)
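For intuition about what these two calls should return, here is a minimal sketch that uses only the standard library and does not call the package. It assumes the textbook reverse-mode identities for C = A * B, namely Ā = C̄ Bᵀ and B̄ = Aᵀ C̄, and checks one entry against a finite difference. The names `Ā_manual`, `B̄_manual`, `f`, `E` and `fd` are illustrative and are not part of DiffLinearAlgebra; whether the package computes its results this way internally is not confirmed here.

```julia
using LinearAlgebra

A, B = randn(5, 3), randn(3, 4)
C, C̄ = A * B, randn(5, 4)

# Textbook reverse-mode identities for C = A * B
# (an assumption for illustration, not taken from the package source).
Ā_manual = C̄ * B'   # same size as A: 5×3
B̄_manual = A' * C̄   # same size as B: 3×4

# Seeding with C̄ corresponds to differentiating the scalar f = <C̄, A * B>.
f(A, B) = dot(C̄, A * B)

# Central finite difference w.r.t. a single entry of A.
ε = 1e-6
E = zeros(size(A)); E[2, 3] = 1.0
fd = (f(A + ε * E, B) - f(A - ε * E, B)) / (2ε)

@assert isapprox(fd, Ā_manual[2, 3]; atol=1e-6)
```

If the package follows the same convention, `Ā` and `B̄` from the example above should agree with `Ā_manual` and `B̄_manual` up to floating-point error.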
We also expose some "metadata" for each implemented sensitivity. This is done via a set called ops, which contains DiffOp structs. These structs record the argument types supported by each sensitivity and which arguments are differentiable.