Open elisno opened 3 years ago
The question might be what is fixed and what you might want to compute the derivative of. I originally conceived Kronecker to work with systems of the form f(K * w), where you might want to optimize w as a parameter matrix. This should be easy enough.
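As a rough sketch of why this case is easy, assume K = A ⊗ B is held fixed and real-valued, and w is a single column with w = vec(W). The symbols W, g, and G are introduced here for illustration and do not appear in the thread. Under those assumptions neither the product nor its gradient needs the dense K:

```latex
% Assumptions: A is m x n, B is p x q, W is q x n, vec stacks columns.
% g = \partial f / \partial (K w), and G is g reshaped to p x m.
\begin{aligned}
K w &= (A \otimes B)\,\operatorname{vec}(W) = \operatorname{vec}\!\left(B W A^{\top}\right), \\
\frac{\partial f}{\partial w} &= K^{\top} g = (A^{\top} \otimes B^{\top})\,\operatorname{vec}(G) = \operatorname{vec}\!\left(B^{\top} G A\right).
\end{aligned}
```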
Taking the gradients of the Kronecker matrix itself would be A ⊗ B => I ⊗ B and A ⊗ I.
Maybe dot(x, A, y) might also be a special case?
I have been working with ChainRulesCore, so you might open a PR and we can look together?
> Taking the gradients of the Kronecker matrix itself would be A ⊗ B => I ⊗ B and A ⊗ I.
It doesn't appear to be quite that straightforward. Care must be taken in setting the appropriate size of I for each partial derivative, and a (conjugate?) transpose needs to be taken of some of the matrices.
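The size bookkeeping can be made explicit with the mixed-product property. This is a sketch, assuming A is m × n and B is p × q (dimension names introduced here, not taken from the thread):

```latex
\begin{aligned}
A \otimes B &= (A \otimes I_p)(I_n \otimes B) = (I_m \otimes B)(A \otimes I_q), \\
\mathrm{d}(A \otimes B) &= \mathrm{d}A \otimes B + A \otimes \mathrm{d}B .
\end{aligned}
```

The identity paired with each factor therefore takes its size from the other factor, and which dimension it takes depends on the side it is factored out on. When both factors are row vectors (m = p = 1), this reduces to A ⊗ B = A (I_n ⊗ B) = B (A ⊗ I_q). The transposes enter when the differential is converted into a gradient.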
> I have been working with ChainRulesCore, so you might open a PR and we can look together?
I've managed to put together a semi-working example with the eager kron and Zygote.gradient. I'd have to review how I do the first steps with the chain rule. I'll open a PR today.
using LinearAlgebra
using Random
using Zygote
M, N = 3, 2
n_samples = 3
Random.seed!(0)
A = rand(1, N)
B = rand(1, M)
x = rand(M*N, n_samples)
y = rand(n_samples)
model(A, B, X) = kron(A, B) * X
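# Squared-error loss; y is read from the enclosing (global) scope.
function loss(A, B, X)
    Z = model(A, B, X) - y'   # residuals, 1 × n_samples
    L = 0.5 * Z * Z'          # 1 × 1 matrix
    return L[1]
end
# Hand-written gradient w.r.t. A, using kron(A, B) == A * kron(I, B) for a row vector A.
function gradient_A(A, B, x)
    Z = model(A, B, x) - y'
    n = size(A, 2)
    IA_col = Diagonal(ones(n))   # n × n identity
    return Z * (kron(IA_col', B) * x)'
end
# Hand-written gradient w.r.t. B, using kron(A, B) == B * kron(A, I) for a row vector B.
function gradient_B(A, B, x)
    Z = model(A, B, x) - y'
    n = size(B, 2)
    IB_col = Diagonal(ones(n))   # n × n identity
    return Z * (kron(A, IB_col) * x)'
end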
# Compare hand-written gradients with running Zygote.gradient on the loss function
@assert gradient_A(A, B, x) ≈ gradient(loss, A, B, x)[1]
@assert gradient_B(A, B, x) ≈ gradient(loss, A, B, x)[2]
# Show partial derivatives of the loss function w.r.t. the Kronecker factors.
@show gradient(loss, A, B, x)[1:2]
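The hand-written gradients above rely on A and B being row vectors. For general matrix factors, one possible sketch of the reverse-mode rule for the eager kron is below. Note that kron_pullback is a hypothetical name introduced here, not an API of Kronecker.jl, Zygote, or ChainRules, and the sketch assumes real entries. It uses only the fact that block (i, j) of kron(A, B) is A[i, j] * B.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Kronecker.jl, Zygote, or ChainRules):
# given dY = ∂loss/∂Y for Y = kron(A, B) with real entries,
# return (∂loss/∂A, ∂loss/∂B).
function kron_pullback(A::AbstractMatrix, B::AbstractMatrix, dY::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    size(dY) == (m * p, n * q) || throw(DimensionMismatch("dY must match kron(A, B)"))
    T = promote_type(eltype(A), eltype(B), eltype(dY))
    dA = zeros(T, m, n)
    dB = zeros(T, p, q)
    for j in 1:n, i in 1:m
        # Block (i, j) of kron(A, B) is A[i, j] * B.
        rows = ((i - 1) * p + 1):(i * p)
        cols = ((j - 1) * q + 1):(j * q)
        blk = view(dY, rows, cols)
        dA[i, j] = sum(blk .* B)     # sensitivity of the scalar A[i, j]
        dB .+= A[i, j] .* blk        # every block contributes to B
    end
    return dA, dB
end

# Small deterministic check against a central finite difference.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]       # 3 × 2
B = [0.5 -1.0 2.0; 1.5 0.0 -0.5]      # 2 × 3
C = reshape(collect(1.0:36.0), 6, 6) ./ 10
f(A, B) = sum(C .* kron(A, B))         # linear in kron(A, B), so ∂f/∂Y = C
dA, dB = kron_pullback(A, B, C)

h = 1e-4
E = zeros(size(A)); E[2, 1] = 1.0      # perturb a single entry of A
fd_A21 = (f(A + h * E, B) - f(A - h * E, B)) / (2h)
```

For the loss in the example above, the cotangent passed in would be Z * x' with Z = model(A, B, x) - y', which should reproduce gradient_A and gradient_B. A ChainRulesCore rrule could wrap the same computation.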
> Maybe the dot(x, A, y) might also be a special case?
What did you have in mind for this?
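One possible reading, offered only as a guess at what was meant: when the middle argument is a Kronecker product, the bilinear form dot(x, A ⊗ B, y) can be evaluated and differentiated without forming the product. The sketch assumes real entries, A of size m × n, B of size p × q, x = vec(X) with X of size p × m, and y = vec(Y) with Y of size q × n (X and Y are names introduced here):

```latex
\begin{aligned}
x^{\top} (A \otimes B)\, y &= \operatorname{tr}\!\left(X^{\top} B\, Y A^{\top}\right), \\
\frac{\partial}{\partial A}\, x^{\top} (A \otimes B)\, y &= X^{\top} B\, Y, \qquad
\frac{\partial}{\partial B}\, x^{\top} (A \otimes B)\, y = X A\, Y^{\top}.
\end{aligned}
```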
(Related to #11)
I'm trying to wrap my head around getting gradients with kron/kronecker.