JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Implicit gradient failing with matrices #137

Closed: theogf closed this issue 4 years ago

theogf commented 4 years ago

Here is an MWE (minimal working example):

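using KernelFunctions, Flux, LinearAlgebra

# Squared exponential kernel with inputs scaled by 2.0
k = transform(SqExponentialKernel(), 2.0)
ps = Flux.params(k)
X = rand(10, 1); x = vec(X)  # the same 10 observations, as a 10×1 matrix and as a vector
A = rand(10, 10)

# 1. Implicit gradient, matrix input with rows as observations: reported to return nothing
g = gradient(ps) do
  tr(kernelmatrix(k, X, obsdim = 1) * A)
end
isnothing(g[ps[1]])

# 2. Explicit gradient w.r.t. the kernel, same matrix input: reported to work
g2 = gradient(k) do k
  tr(kernelmatrix(k, X, obsdim = 1) * A)
end
!isnothing(g2[1].transform.s)

# 3. Implicit gradient, vector input: reported to work
g3 = gradient(ps) do
  tr(kernelmatrix(k, x) * A)
end
!isnothing(g3[ps[1]])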
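For comparison, a package-free reference value for the derivative of the loss with respect to the scale s can help tell a missing gradient from a wrong one. This is a plain-Julia sketch, not KernelFunctions code: kmat, loss, dloss_ds and fd_dloss_ds are hypothetical helpers, and the kernel exp(-(s(xi - xj))²/2) with s = 2.0 is an assumption standing in for transform(SqExponentialKernel(), 2.0); the SqExponentialKernel definition in the version used here may differ by a constant factor in the exponent. The analytic derivative uses dK_ij/ds = -s (xi - xj)² K_ij and tr(K A) = Σ_ij K_ij A_ji, and a central finite difference checks it.

```julia
using LinearAlgebra

# Assumed kernel on scalar inputs: K_ij = exp(-(s * (xi - xj))^2 / 2).
kmat(s, x) = [exp(-(s * (xi - xj))^2 / 2) for xi in x, xj in x]

# Loss from the MWE (vector-input case): tr(K(s) * A).
loss(s, x, A) = tr(kmat(s, x) * A)

# Analytic derivative: dK_ij/ds = -s * (xi - xj)^2 * K_ij,
# and d/ds tr(K * A) = sum over i, j of dK_ij/ds * A_ji.
function dloss_ds(s, x, A)
    K = kmat(s, x)
    dK = [-s * (xi - xj)^2 for xi in x, xj in x] .* K
    return sum(dK .* transpose(A))
end

# Central finite difference, used as an independent check.
fd_dloss_ds(s, x, A; h = 1e-6) = (loss(s + h, x, A) - loss(s - h, x, A)) / (2h)
```

Any working AD path (g, g2 or g3) should agree with dloss_ds up to the assumed kernel convention.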

I think this is related to https://github.com/FluxML/Zygote.jl/issues/692. Any idea how to solve this, @willtebbutt? It is probably connected to the ColVecs structure.
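In KernelFunctions, ColVecs and RowVecs wrap a matrix so that its columns, respectively rows, are treated as a vector of observations; with obsdim = 1 the rows of X are the observations. The following minimal sketch reproduces that idea in plain Julia without the package. MyColVecs, MyRowVecs, sqexp and sqexp_kernelmatrix are hypothetical names, and the kernel exp(-‖s(x - y)‖²/2) with s = 2.0 is an assumption standing in for the transformed SqExponentialKernel. It shows that the 10×1 matrix read row-wise, its transpose read column-wise, and vec(X) all describe the same ten inputs, which is why g and g3 would be expected to agree.

```julia
using LinearAlgebra

# Hypothetical wrapper: each column of X is one observation.
struct MyColVecs{T<:Real} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(v::MyColVecs) = (size(v.X, 2),)
Base.getindex(v::MyColVecs, i::Int) = v.X[:, i]

# Hypothetical wrapper: each row of X is one observation (obsdim = 1).
struct MyRowVecs{T<:Real} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(v::MyRowVecs) = (size(v.X, 1),)
Base.getindex(v::MyRowVecs, i::Int) = v.X[i, :]

# Assumed kernel: squared exponential on inputs scaled by s.
# norm works for both scalar and vector observations.
sqexp(x, y; s = 2.0) = exp(-norm(s * (x - y))^2 / 2)

# Kernel matrix over any vector of observations.
sqexp_kernelmatrix(xs::AbstractVector; s = 2.0) =
    [sqexp(xi, xj; s = s) for xi in xs, xj in xs]
```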

willtebbutt commented 4 years ago

This is outside my area of expertise I'm afraid.

theogf commented 4 years ago

I think there is a general issue with the adjoints of ColVecs/RowVecs. Do you know who could help with it?

willtebbutt commented 4 years ago

Have you tried wrapping everything in a let block? Globals are hard, so it's possible that Zygote is buggy w.r.t. them.

edit: I'm not sure exactly how the ColVecs etc. pullbacks would affect this. If they work under usual circumstances, I would expect them to work here 🤷‍♂️

theogf commented 4 years ago

You mean this?

let kernel = k
  g = gradient(ps) do
    tr(kernelmatrix(kernel, X, obsdim = 1) * A)
  end
end

willtebbutt commented 4 years ago

Nah, just wrap everything in a let block:

using KernelFunctions, Flux, LinearAlgebra

let
  k = transform(SqExponentialKernel(), 2.0)
  ps = Flux.params(k)
  X = rand(10, 1); x = vec(X)
  A = rand(10, 10)

  g = gradient(ps) do
    tr(kernelmatrix(k, X, obsdim = 1) * A)
  end
  isnothing(g[ps[1]])

  g2 = gradient(k) do k
    tr(kernelmatrix(k, X, obsdim = 1) * A)
  end
  !isnothing(g2[1].transform.s)

  g3 = gradient(ps) do
    tr(kernelmatrix(k, x) * A)
  end
  !isnothing(g3[ps[1]])
end

theogf commented 4 years ago

Nope, same behavior.

theogf commented 4 years ago

I found a fix \o/! I think we should avoid relying on Base.map: removing it and calling _map directly solves the problem. I think this is connected to https://github.com/FluxML/Zygote.jl/issues/522, which you and Mike apparently already looked at. It also looks like the adjoints for ColVecs and RowVecs are not necessary. I will make a PR with a fix.
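The thread does not show the actual _map code, so the following is only a plain-Julia sketch of one plausible reading of "avoid relying on Base.map": instead of mapping an input transform over a ColVecs-like wrapper one observation at a time, apply it once to the underlying matrix and re-wrap the result. MyColVecs, ScaleSketch, map_per_obs and map_whole are hypothetical names, not KernelFunctions API. Both versions yield the same observations; the whole-matrix version reduces the work to a single broadcast on an ordinary matrix, which is a simpler path for an AD tool to differentiate.

```julia
# Hypothetical wrapper: each column of X is one observation.
struct MyColVecs{T<:Real} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(v::MyColVecs) = (size(v.X, 2),)
Base.getindex(v::MyColVecs, i::Int) = v.X[:, i]

# Hypothetical input transform: multiply every input by s.
struct ScaleSketch{T<:Real}
    s::T
end
(t::ScaleSketch)(x) = t.s * x

# Per-observation version: Base.map visits each column separately
# and returns a plain Vector of Vectors.
map_per_obs(t::ScaleSketch, xs::MyColVecs) = map(t, xs)

# Whole-matrix version: scale the underlying matrix in one broadcast
# and wrap the result again, keeping the column-observation structure.
map_whole(t::ScaleSketch, xs::MyColVecs) = MyColVecs(t.s .* xs.X)
```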