MichielStock / Kronecker.jl

A general-purpose toolbox for efficient Kronecker-based algebra.

AD rules that apply to KroneckerProducts #92

Open · elisno opened this issue 3 years ago

elisno commented 3 years ago

(Related to #11)

I'm trying to wrap my head around getting gradients with kron/kronecker.

  1. Is it sufficient to define custom AD rules for the vec-trick with ChainRulesCore.jl? (A rough sketch follows after this list.)
function rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    function times_vec_pullback(ΔΩ)
        ...
    end
    return K*x, times_vec_pullback
end

function rrule(::typeof(*), K::KroneckerProduct, X::AbstractMatrix)
    function times_mat_pullback(ΔΩ)
        ...
    end
    return K*X, times_mat_pullback
end
  2. Do we also need to define rules for the constructor to get gradients?
function rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_pullback(ΔΩ)
        ...
    end
    return kronecker(A, B), kronecker_pullback
end
  3. Should the pullbacks also be lazy? I found this to be a decent overview on finding vectorized derivatives. Would the pullbacks then just be reshape rules for these vectorized derivatives?
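
For item 1, a rough sketch of what such a rule could look like, assuming the incoming cotangent ΔΩ is a dense vector and that the lazy adjoint K' is usable in the reverse pass (names and details are only illustrative):

using ChainRulesCore, Kronecker, LinearAlgebra

# Sketch only: pullback of y = K * x for a lazy Kronecker product K = A ⊗ B.
function ChainRulesCore.rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    y = K * x
    function times_vec_pullback(ΔΩ)
        Δy = unthunk(ΔΩ)
        Δx = @thunk(K' * Δy)   # (A ⊗ B)' * Δy = (A' ⊗ B') * Δy, so the vec-trick applies again
        ΔK = @thunk(Δy * x')   # cotangent of K as an outer product; kept dense in this sketch
        return NoTangent(), ΔK, Δx
    end
    return y, times_vec_pullback
end

A real rule would presumably also want ProjectTo to keep ΔK structured (or lazy), plus analogous rules for the matrix and adjoint cases.
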
MichielStock commented 3 years ago

The question might be what is fixed and what you want to compute the derivative of. I originally conceived Kronecker to work with systems of the form f(K * w), where you might want to optimize w as a parameter matrix. This should be easy enough.
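
For instance, with an eager K this case is just a matrix-vector product as far as AD is concerned (a minimal sketch with made-up sizes):

using LinearAlgebra
using Zygote

# f(K * w): K = A ⊗ B is fixed data, w is the parameter being optimized.
A, B = rand(3, 3), rand(4, 4)
K = kron(A, B)            # eager; the lazy kronecker(A, B) is what needs the rules
w, y = rand(12), rand(12)

loss(w) = 0.5 * sum(abs2, K * w - y)

# Zygote's gradient matches the analytic gradient K' * (K * w - y).
@assert Zygote.gradient(loss, w)[1] ≈ K' * (K * w - y)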

Taking the gradients of the Kronecker matrix itself would be something like A ⊗ B => I ⊗ B and A ⊗ I.

Maybe dot(x, A, y) is also a special case?

I have been working with ChainRulesCore, so you might open a PR and we can look together?

elisno commented 3 years ago

> Taking the gradients of the Kronecker matrix itself would be something like A ⊗ B => I ⊗ B and A ⊗ I.

It doesn't appear to be quite that straightforward. Care must be taken when setting the appropriate size of I for each partial derivative, and a (conjugate?) transpose needs to be taken of some of the matrices.
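
Concretely, in the example below (A is 1×N, B is 1×M, and Z = (A ⊗ B)X - y' is the residual), the hand-written gradients are ∂L/∂A = Z * ((I_N ⊗ B) * X)' and ∂L/∂B = Z * ((A ⊗ I_M) * X)', so the identity has to be N×N in one case and M×M in the other, and the transpose sits on the whole product (a conjugate transpose for complex inputs).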

> I have been working with ChainRulesCore, so you might open a PR and we can look together?

I've managed to put together a semi-working example with the eager kron and Zygote.gradient. I'd have to review how I do the first steps with the chain rule. I'll open a PR today.

using LinearAlgebra
using Random
using Zygote

M, N = 3, 2
n_samples = 3

Random.seed!(0)
A = rand(1, N)
B = rand(1, M)
x = rand(M*N, n_samples)
y = rand(n_samples)

model(A, B, X) = kron(A, B) * X

function loss(A, B, X)
    Z = model(A, B, X) - y'
    L = 0.5 * Z * Z'
    return L[1]
end

function gradient_A(A, B, x)
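    # Hand-derived ∂L/∂A = Z * ((I_N ⊗ B) * x)', where I_N is the N×N identity.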
    Z = model(A, B, x) - y'
    n = size(A, 2)
    IA_col = Diagonal(ones(n))
    return Z * (kron(IA_col', B) * x)'
end

function gradient_B(A, B, x)
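    # Hand-derived ∂L/∂B = Z * ((A ⊗ I_M) * x)', where I_M is the M×M identity.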
    Z = model(A, B, x) - y'
    n = size(B, 2)
    IB_col = Diagonal(ones(n))
    return  Z * (kron(A, IB_col) * x)'
end

# Compare hand-written gradients with running Zygote.gradient on the loss function
@assert gradient_A(A, B, x) ≈ gradient(loss, A, B, x)[1]
@assert gradient_B(A, B, x) ≈ gradient(loss, A, B, x)[2]

# Show partial derivatives of the loss function w.r.t. the Kronecker factors.
@show gradient(loss, A, B, x)[1:2]

> Maybe dot(x, A, y) is also a special case?

What did you have in mind for this?
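
If it's the three-argument dot, one plausible reading is that dot(x, A ⊗ B, y) never needs the full matrix, since x' * (A ⊗ B) * y == dot(X, B * Y * transpose(A)) with X and Y the column-major reshapes of x and y. A quick check of that identity (square A and B just for simplicity):

using LinearAlgebra

A, B = rand(3, 3), rand(4, 4)
x, y = rand(12), rand(12)
X, Y = reshape(x, 4, 3), reshape(y, 4, 3)

# (A ⊗ B) * vec(Y) == vec(B * Y * transpose(A)), so the bilinear form collapses:
@assert dot(x, kron(A, B), y) ≈ dot(X, B * Y * transpose(A))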