JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Slow mode for kerneldiagmatrix #203

Closed: theogf closed this issue 3 years ago

theogf commented 3 years ago

Zygote currently fails to differentiate through kerneldiagmatrix when given a RowVecs or ColVecs input.

MWE:

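using KernelFunctions, Zygote

# Three 3-dimensional inputs, one per row of the matrix
X = KernelFunctions.RowVecs(rand(3, 3))
# Squared exponential kernel applied to inputs scaled by 2.0
k = transform(SqExponentialKernel(), 2.0)
# Differentiate the sum of the kernel matrix diagonal w.r.t. the kernel
Zygote.gradient(k) do k
    sum(kerneldiagmatrix(k, X))
end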

ERROR: In slow method
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::KernelFunctions.var"#back#186")(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/KernelFunctions/V02nz/src/zygote_adjoints.jl:75
 [3] (::KernelFunctions.var"#171#back#187"{KernelFunctions.var"#back#186"})(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] _map at /home/theo/.julia/packages/KernelFunctions/V02nz/src/transform/scaletransform.jl:26 [inlined]
 [5] (::typeof(∂(_map)))(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0
 [6] kerneldiagmatrix at /home/theo/.julia/packages/KernelFunctions/V02nz/src/kernels/transformedkernel.jl:85 [inlined]
 [7] (::typeof(∂(kerneldiagmatrix)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0
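For reference, the quantity being differentiated is simple: kerneldiagmatrix(k, X) returns the vector [k(x_i, x_i)], and for a stationary kernel such as the squared exponential every entry is 1, whatever the input scale. The plain-Julia sketch below recomputes it without KernelFunctions or Zygote, assuming that the scale transform multiplies every input by 2.0 and that the kernel has the form exp(-‖x - y‖² / 2); the helper names `sqexp` and `diag_rows` are hypothetical. It suggests that the true derivative of the sum w.r.t. the scale is zero, so the error above comes from the adjoint machinery rather than from the maths.

```julia
# Hypothetical stand-in for the MWE, without KernelFunctions or Zygote.
# Assumption: kernel value is exp(-‖s*x - s*y‖² / 2) for input scale s.
sqexp(x, y, s) = exp(-sum(abs2, s .* x .- s .* y) / 2)

# Diagonal of the kernel matrix for inputs stored as the rows of X (RowVecs-style).
diag_rows(X, s) = [sqexp(X[i, :], X[i, :], s) for i in axes(X, 1)]

X = rand(3, 3)
s = 2.0
d = diag_rows(X, s)   # each entry is exp(0) = 1

# Central finite difference of sum(d) w.r.t. the scale s.
h = 1e-6
fd = (sum(diag_rows(X, s + h)) - sum(diag_rows(X, s - h))) / (2h)
println(d, " ", fd)
```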
willtebbutt commented 3 years ago

Right. What's happening is that, on the reverse pass, the gradient w.r.t. the output of

_map(κ.transform, x)

found here is a Vector{Vector{Float64}}, which in turn hits the adjoint for the RowVecs constructor found here.
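Mathematically, such a cotangent is easy to convert: the gradient w.r.t. row i of the wrapped matrix is just the i-th inner vector. A minimal sketch of that conversion, assuming a plain `Matrix` underneath; `rowvecs_cotangent` is a hypothetical name, not a KernelFunctions function. It works, but it first requires one separately allocated vector per input, which is the kind of slow path discussed next.

```julia
# Hypothetical helper: turn a vector-of-vectors cotangent (one gradient per row
# of the wrapped matrix) back into a matrix cotangent for that matrix.
function rowvecs_cotangent(Δ::AbstractVector{<:AbstractVector})
    # reduce(hcat, Δ) stores Δ[i] as column i; permutedims makes it row i.
    return permutedims(reduce(hcat, Δ))
end

Δ = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
ΔX = rowvecs_cotangent(Δ)   # 2×3 matrix whose row i equals Δ[i]
```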

When I first wrote the ColVecs code in Stheno, I actively wanted to avoid ever accidentally hitting the slow path that involves pulling out individual elements of a ColVecs, so I made the adjoint throw an error if this ever happened (i.e. if it received an actual vector of vectors as a derivative).
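To make the "individual elements" point concrete, here is a minimal ColVecs-style wrapper in plain Julia. It is a hypothetical sketch (`MyColVecs`), not the Stheno or KernelFunctions implementation: the data stays in one matrix, and each element access copies a column, so code that works element by element produces a vector of vectors rather than a matrix.

```julia
# Hypothetical minimal ColVecs-style wrapper: a vector whose i-th element is
# the i-th column of a matrix.
struct MyColVecs <: AbstractVector{Vector{Float64}}
    X::Matrix{Float64}
end

Base.size(x::MyColVecs) = (size(x.X, 2),)
Base.IndexStyle(::Type{MyColVecs}) = IndexLinear()
# Element access copies one column out of the matrix: the per-element path.
Base.getindex(x::MyColVecs, i::Int) = x.X[:, i]

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
x = MyColVecs(X)   # three 2-dimensional inputs
```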

This isn't an issue in Stheno, which suggests that we're missing, or have opted not to implement, something here. I believe the issue is that we haven't implemented a specialised version of kerneldiagmatrix for SimpleKernels. Since every SimpleKernel relies on a metric, for which colwise exists, we can do this with something like

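function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector)
    # The diagonal k(x[i], x[i]) only needs the distance of each input to itself,
    # so use colwise rather than pairwise (which would build the full matrix).
    return map(d -> kappa(κ, d), colwise(metric(κ), x, x))
end
function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
    # Entries k(x[i], y[i]) only need the distances between matching inputs.
    return map(d -> kappa(κ, d), colwise(metric(κ), x, y))
end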

with additional error checking. This should both fix this issue and improve the performance of kerneldiagmatrix for SimpleKernels.
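The idea behind the proposal can be checked in plain Julia without Distances.jl or KernelFunctions. This sketch assumes ColVecs-style inputs stored as matrix columns, a squared Euclidean metric and the squared exponential form kappa(d) = exp(-d / 2); `sqeuclidean_colwise`, `sqexp_kappa`, `diag_kernel` and `full_kernel` are hypothetical names. It computes distances only between matching columns, includes a length check of the sort "additional error checking" might mean, and compares the result with the diagonal of the full kernel matrix.

```julia
using LinearAlgebra

# Hand-written colwise squared Euclidean distance: column i of X against column i of Y.
sqeuclidean_colwise(X, Y) = vec(sum(abs2, X .- Y; dims=1))

# Assumed squared exponential kappa, as a function of the squared distance.
sqexp_kappa(d) = exp(-d / 2)

function diag_kernel(X::AbstractMatrix, Y::AbstractMatrix)
    # Both inputs must contain the same number of points (columns).
    size(X, 2) == size(Y, 2) ||
        throw(DimensionMismatch("X and Y must have the same number of columns"))
    return map(sqexp_kappa, sqeuclidean_colwise(X, Y))
end
diag_kernel(X::AbstractMatrix) = diag_kernel(X, X)

# Full kernel matrix, used only as a reference for comparison.
full_kernel(X, Y) = [sqexp_kappa(sum(abs2, X[:, i] - Y[:, j])) for i in axes(X, 2), j in axes(Y, 2)]

X = randn(2, 4)
Y = randn(2, 4)
```

The colwise route touches n distances instead of n², and it never indexes individual inputs, which is why it sidesteps the vector-of-vectors cotangent.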

Does this seem reasonable @theogf ?

willtebbutt commented 3 years ago

It's also a bit surprising that this wasn't picked up by our tests, as it suggests that we're not testing some pretty basic kernel combinations properly.
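One way to catch gaps like this is to test every combination of kernel, input layout and transform instead of a few hand-picked cases. The grid below is a hypothetical plain-Julia sketch (the kernel forms, `as_cols`, `as_rows`, `fullmatrix` and `fastdiag` are illustrative assumptions) and only checks forward consistency of the diagonal; a real test suite would also need to compare AD gradients, e.g. from Zygote, against finite differences, which is not shown here.

```julia
using LinearAlgebra

# Kernels written as functions of the squared distance d2 (assumed forms).
kernels = [d2 -> exp(-d2 / 2), d2 -> exp(-sqrt(d2)), d2 -> 1 / (1 + d2 / 2)]

# The two input layouts from the issue: one input per column, or one per row.
as_cols(A) = [A[:, i] for i in axes(A, 2)]
as_rows(A) = [A[i, :] for i in axes(A, 1)]

sqdist(a, b) = sum(abs2, a .- b)

# Full kernel matrix and a diagonal-only version for inputs scaled by s.
fullmatrix(κ, x, y, s) = [κ(sqdist(s .* a, s .* b)) for a in x, b in y]
fastdiag(κ, x, y, s) = map((a, b) -> κ(sqdist(s .* a, s .* b)), x, y)

A, B = randn(4, 4), randn(4, 4)
checked = 0
failures = 0
for (κ, layout, s) in Iterators.product(kernels, (as_cols, as_rows), (0.5, 2.0))
    x, y = layout(A), layout(B)
    global checked += 1
    if !(fastdiag(κ, x, y, s) ≈ diag(fullmatrix(κ, x, y, s)))
        global failures += 1
    end
end
```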

theogf commented 3 years ago

If your solution works I am all for it! Did you try it out?