JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Issue with gradients of Periodic Matern-1/2 #474

Closed (wil-j-wil closed this 1 year ago)

wil-j-wil commented 2 years ago

I have found a case in which the gradient (using Zygote) of a kernel with respect to its lengthscale seems to be incorrect: a Matern-1/2 kernel with a periodic transform.

Here's an example:

using FiniteDifferences, KernelFunctions, Zygote
using KernelFunctions: kernelmatrix

X = collect(1.0:100.0);

function kernel(θ)
    return with_lengthscale(Matern12Kernel(), θ) ∘ PeriodicTransform(1 / 24.0)
end

objective(θ) = sum(kernelmatrix(kernel(θ), X))

len = 2.0;

∇zyg = Zygote.gradient(objective, len)  # = 1083.35
∇fd = FiniteDifferences.grad(central_fdm(5, 1), objective, len)  # = 1485.21

Note that removing the periodic transform or using a different Matern kernel results in the gradients being correct.
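As a cross-check that depends on neither AD nor package internals, the objective can be written out by hand. With φ(x) the periodic features and r_ij = ‖φ(x_i) - φ(x_j)‖, objective(θ) = Σ_ij exp(-r_ij / θ), and its θ-derivative is Σ_ij (r_ij / θ²) exp(-r_ij / θ). The plain-Julia sketch below assumes that PeriodicTransform(f) maps x to (sinpi(2fx), cospi(2fx)), that with_lengthscale divides the transformed inputs by θ, and that Matern12Kernel is exp(-d) on Euclidean distance. The helper names are hypothetical.

```julia
# AD-free reference for objective(θ) = sum(kernelmatrix(kernel(θ), X)).
# Assumptions (not from the thread): PeriodicTransform(f) maps
# x ↦ (sinpi(2fx), cospi(2fx)), with_lengthscale divides the transformed
# inputs by θ, and Matern12Kernel is κ(d) = exp(-d).

# Unscaled feature-space distance r = ‖φ(x1) - φ(x2)‖, computed from the
# coordinate differences directly, so identical features give exactly 0.
function feature_dist(x1, x2, f)
    Δs = sinpi(2 * f * x1) - sinpi(2 * f * x2)
    Δc = cospi(2 * f * x1) - cospi(2 * f * x2)
    return hypot(Δs, Δc)
end

# Each kernel matrix entry is exp(-r / θ).
reference_objective(θ, X, f) = sum(exp(-feature_dist(a, b, f) / θ) for a in X, b in X)

# d/dθ exp(-r / θ) = (r / θ^2) * exp(-r / θ), which is 0 when r == 0.
grad_term(r, θ) = r / θ^2 * exp(-r / θ)
reference_gradient(θ, X, f) = sum(grad_term(feature_dist(a, b, f), θ) for a in X, b in X)

# Plain central difference, used only to cross-check the hand-derived gradient.
central_diff(g, θ; h = 1e-5) = (g(θ + h) - g(θ - h)) / (2h)

X = collect(1.0:100.0)
f = 1 / 24.0
θ = 2.0
g_ref = reference_gradient(θ, X, f)
g_fd = central_diff(t -> reference_objective(t, X, f), θ)
```

Pairs with r_ij = 0 contribute exactly zero here, since their kernel entry exp(0) = 1 does not depend on θ, so this gives a number to compare the Zygote and FiniteDifferences results against.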

Any ideas what's going on here?

Edit: this bug seems to have been introduced in version 0.10.42, i.e., by PR #466

wil-j-wil commented 2 years ago

Further info:

Using an equivalent FunctionTransform instead gives the correct result, but is much less efficient:

period = 24.0  # PeriodicTransform(1 / 24.0) has frequency 1 / period
f(x) = [sinpi(2 / period * x), cospi(2 / period * x)]
kernel(θ) = with_lengthscale(Matern12Kernel(), θ) ∘ FunctionTransform(f)

Using version 0.10.44 I get the following with the PeriodicTransform:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  264.250 μs …   2.547 ms  ┊ GC (min … max):  0.00% … 84.84%
 Time  (median):     274.625 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   321.500 μs ± 295.051 μs  ┊ GC (mean ± σ):  13.45% ± 12.59%

  █▃                                                          ▁ ▁
  ██▅▄▅▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▇█ █
  264 μs        Histogram: log(frequency) by time       2.25 ms <

 Memory estimate: 1014.02 KiB, allocs estimate: 950.

and then with the FunctionTransform:

BenchmarkTools.Trial: 629 samples with 1 evaluation.
 Range (min … max):  6.443 ms … 13.104 ms  ┊ GC (min … max):  0.00% … 42.89%
 Time  (median):     6.720 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   7.955 ms ±  2.346 ms  ┊ GC (mean ± σ):  14.74% ± 17.76%

  ▄█▇                                                         
  ███▇▄▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▃▄▃▅▅▄▃▂ ▃
  6.44 ms        Histogram: frequency by time        12.7 ms <

 Memory estimate: 9.82 MiB, allocs estimate: 137921.

Digging deeper, the cause seems to be the use of _map rather than map in this line (_map speeds things up, but also triggers this bug).

wil-j-wil commented 2 years ago

Seems to be a type stability issue when using sinpi / cospi? The following hack fixes things:

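_pi = Float32(pi)  # deliberately a Float32 approximation of π

# Hack: override the internal _map for PeriodicTransform, using sin/cos of
# 2π·f·x in place of sinpi/cospi
function KernelFunctions._map(t::PeriodicTransform, x::AbstractVector{<:Real})
    return RowVecs(hcat(sin.((2 * _pi * only(t.f)) .* x), cos.((2 * _pi * only(t.f)) .* x)))
end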

Edit: but, weirdly, _pi = Float64(pi) and _pi = eltype(x)(pi) don't give the correct result. Maybe it's not type stability, but rather some obscure numerical issue?

willtebbutt commented 2 years ago

Hmmm, this is very very strange. It seems like it might be some kind of numerical thing. Could you confirm that you get roughly consistent answers if you modify X to something like X = 1.0:1.147:100.0?

wil-j-wil commented 2 years ago

yep, changing the step by even a tiny amount fixes things, e.g. X = 1.0:1.001:100.0.
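A rough way to see why the grid matters: with f = 1/24, integer inputs 24 apart land on (numerically) the same periodic feature point, so many pairs have a feature-space distance of essentially zero, while on the 1.001-step grid none do. A plain-Julia count, with the hypothetical helper count_near_duplicates and an arbitrarily chosen threshold of 1e-6, assuming the features are (sinpi(2fx), cospi(2fx)):

```julia
# Count pairs i < j whose periodic features (assumed to be
# (sinpi(2fx), cospi(2fx)), as for PeriodicTransform(f)) lie within tol.
function count_near_duplicates(X, f; tol = 1e-6)
    n = 0
    for i in eachindex(X), j in eachindex(X)
        i < j || continue  # skip the diagonal and count each pair once
        Δs = sinpi(2 * f * X[i]) - sinpi(2 * f * X[j])
        Δc = cospi(2 * f * X[i]) - cospi(2 * f * X[j])
        n += hypot(Δs, Δc) < tol
    end
    return n
end

count_near_duplicates(1.0:100.0, 1 / 24.0)        # 160 (offsets 24, 48, 72, 96)
count_near_duplicates(1.0:1.001:100.0, 1 / 24.0)  # 0
```

Diagonal pairs (i == j) are excluded because they are exactly zero on any grid.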

willtebbutt commented 1 year ago

Hmmm, I'm wondering whether this is just because the Matern-1/2 kernel is non-differentiable w.r.t. its inputs at zero distance? Maybe the finite differencing algorithm is just smoothing something out?
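A scalar illustration of the smoothing idea, as a plain-Julia sketch (g, central_diff and naive_chain_rule are hypothetical names): at a kink, a symmetric finite difference averages the two one-sided slopes, while applying the chain rule through sqrt literally produces Inf * 0.

```julia
# g(x) = sqrt(x^2) = |x| has a kink at x = 0.
g(x) = sqrt(x^2)

# Central finite difference: symmetric, so at x = 0 the two sides cancel.
central_diff(g, x; h = 1e-3) = (g(x + h) - g(x - h)) / (2h)

# Chain rule applied literally with u = x^2:
# d/dx sqrt(u) = 1 / (2 * sqrt(u)) * du/dx.
naive_chain_rule(x) = (1 / (2 * sqrt(x^2))) * (2x)

central_diff(g, 0.0)   # 0.0
naive_chain_rule(0.0)  # NaN, from Inf * 0
naive_chain_rule(2.0)  # 1.0, fine away from the kink
```

Neither value at 0 is "the" derivative; they are just two different conventions at a point where none exists.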

willtebbutt commented 1 year ago

But then that wouldn't explain changes between versions... hmmmm

willtebbutt commented 1 year ago

Right, I've figured it out! It is kind of because the Matern-1/2 kernel is non-differentiable at 0.

In particular, because of the square root, the Euclidean distance is non-differentiable at the point where the distance between two inputs is zero.

In the previous version, map(k.transform, X) would output a Vector{Vector{Float64}}, while the current version outputs a RowVecs{Float64}. This change is fine in principle, but it means that different pullbacks for pairwise(::Euclidean, ::AbstractVector) get hit. In the new version, we hit this Zygote rule that I implemented a while ago, while in the old version we hit something else; I'm not entirely sure what.

In any case, it means that we need to fix whatever is going on in Zygote / improve its numerics.

The following implementation of the adjoint provides a temporary fix:

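using Distances: Euclidean, SqEuclidean, pairwise
using Zygote: @adjoint, pullback

# Temporary workaround: compute Euclidean distances as the square root of
# squared distances clamped below by δ, so the pullback never sees sqrt(0).
@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
  function _pairwise_euclidean(X)
    δ = eps(eltype(X))^2
    return sqrt.(max.(pairwise(SqEuclidean(), X; dims=dims), δ))
  end
  D, back = pullback(_pairwise_euclidean, X)

  # No gradient for the metric argument; back(Δ) returns a 1-tuple for X.
  return D, function(Δ)
    return (nothing, back(Δ)...)
  end
end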
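What the δ clamp buys can be seen on a scalar. This plain-Julia sketch uses hypothetical helper names and hand-written derivatives with respect to the squared distance s, assuming (as is the usual convention for max) that the derivative of max(s, δ) goes entirely to δ when s < δ:

```julia
# Hypothetical scalar versions of the two distance formulas, as functions
# of the squared distance s.
plain_dist(s) = sqrt(s)
clamped_dist(s; δ = eps(Float64)^2) = sqrt(max(s, δ))

# Hand-written derivatives w.r.t. s. Below δ the clamped version is constant,
# so its derivative is taken as 0 there.
d_plain(s) = 1 / (2 * sqrt(s))
d_clamped(s; δ = eps(Float64)^2) = s > δ ? 1 / (2 * sqrt(s)) : 0.0

# For two coincident points a == b, s = (a - b)^2 and ds/da = 2(a - b) = 0.
dsda = 0.0
d_plain(0.0) * dsda    # NaN, from Inf * 0
d_clamped(0.0) * dsda  # 0.0
clamped_dist(0.0)      # eps(Float64) rather than 0.0
```

The price is a forward value of eps(Float64) instead of exactly zero on coincident pairs, which is negligible for a kernel like exp(-d).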

I'll have to think a bit harder about a good long-term solution. This issue crops up once every couple of years. Fingers crossed it gets solved properly this time, and we can wait at least a decade before having to touch the implementation again...

willtebbutt commented 1 year ago

Okay, I think I've fixed it properly in https://github.com/FluxML/Zygote.jl/pull/1307

@wil-j-wil if you have 5 mins, could you dev --local Zygote, check out the above branch (you'll probably have to add my fork as a remote) and confirm that it does indeed solve your problem? I've tested on my end, but I'd like to confirm that it's working for you before I consider the issue resolved.

wil-j-wil commented 1 year ago

yes this works for me too! thanks @willtebbutt

willtebbutt commented 1 year ago

Great. If you upgrade to the latest patch release of Zygote (should be available in an hour or so), the problem should be fixed. Thanks again for raising this problem @wil-j-wil and digging into it -- it's a good one to have fixed.