wil-j-wil closed this issue 1 year ago.
Further info: using an equivalent FunctionTransform instead gives the correct result, but is much less efficient:
using KernelFunctions
# `period` is defined elsewhere in the example; f maps each input to a point on the unit circle
f(x) = [sinpi(2 / period * x), cospi(2 / period * x)]
kernel(θ) = with_lengthscale(Matern12Kernel(), θ) ∘ FunctionTransform(f)
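For intuition, here is a small Base-Julia sketch (no KernelFunctions needed) of what this periodic embedding does to distances. `periodic_embed`, `embedded_dist` and `closed_form` are hypothetical helper names, not library functions. The identity it checks is standard trigonometry: two embedded points are 2|sin(π(x - y)/period)| apart, so any two inputs that differ by a whole number of periods land on exactly the same point.

```julia
# Hypothetical helper mirroring `f` above: embed x on the unit circle with the given period.
periodic_embed(x, period) = [sinpi(2 / period * x), cospi(2 / period * x)]

# Euclidean distance between two embedded inputs.
embedded_dist(x, y, period) = sqrt(sum(abs2, periodic_embed(x, period) .- periodic_embed(y, period)))

# Closed form of the same distance: ‖f(x) - f(y)‖ = 2|sin(π(x - y)/period)|.
closed_form(x, y, period) = 2 * abs(sinpi((x - y) / period))

@show embedded_dist(0.3, 1.7, 2.5) closed_form(0.3, 1.7, 2.5)
# Inputs exactly one period apart coincide after the embedding (distance 0).
@show embedded_dist(1.0, 3.0, 2.0)
```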
Using version 0.10.44, I get the following with the PeriodicTransform:
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 264.250 μs … 2.547 ms ┊ GC (min … max): 0.00% … 84.84%
Time (median): 274.625 μs ┊ GC (median): 0.00%
Time (mean ± σ): 321.500 μs ± 295.051 μs ┊ GC (mean ± σ): 13.45% ± 12.59%
[Histogram: log(frequency) by time, 264 μs to 2.25 ms]
Memory estimate: 1014.02 KiB, allocs estimate: 950.
and then with the FunctionTransform:
BenchmarkTools.Trial: 629 samples with 1 evaluation.
Range (min … max): 6.443 ms … 13.104 ms ┊ GC (min … max): 0.00% … 42.89%
Time (median): 6.720 ms ┊ GC (median): 0.00%
Time (mean ± σ): 7.955 ms ± 2.346 ms ┊ GC (mean ± σ): 14.74% ± 17.76%
[Histogram: frequency by time, 6.44 ms to 12.7 ms]
Memory estimate: 9.82 MiB, allocs estimate: 137921.
Digging deeper, it seems to be the use of _map rather than map in this line (_map speeds things up, but also causes this bug).
Could it be a type-stability issue with sinpi / cospi? The following hack fixes things:
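using KernelFunctions  # provides PeriodicTransform, RowVecs and the internal _map
_pi = Float32(pi)  # deliberately low-precision π
# Hack: overrides the internal _map method for PeriodicTransform with sin/cos instead of sinpi/cospi
function KernelFunctions._map(t::PeriodicTransform, x::AbstractVector{<:Real})
    return RowVecs(hcat(sin.((2 * _pi * only(t.f)) .* x), cos.((2 * _pi * only(t.f)) .* x)))
end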
Edit: but, weirdly, _pi = Float64(pi) and _pi = eltype(x)(pi) don't give the correct result. Maybe it's not type stability, but rather some obscure numerical issue?
Hmmm, this is very, very strange. It seems like it might be some kind of numerical thing. Could you confirm that if you change X to something like X = 1.0:1.147:100.0, you get roughly consistent answers?
Yep, adding even just a tiny increment fixes things, e.g. X = 1.0:1.001:100.0.
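A sketch of why the grid spacing could matter, assuming a hypothetical period of 10.0 (the period used in the original example is not shown here). On X = 1.0:1.0:100.0, inputs 10 apart map to numerically the same embedded point, so many off-diagonal distances are essentially zero; with the 1.001 step, the closest pairs are still roughly 0.006 apart. `count_coincident` is an illustrative helper, not part of any library.

```julia
# Same periodic embedding as the FunctionTransform example; `period` is an assumed value.
periodic_embed(x, period) = [sinpi(2 / period * x), cospi(2 / period * x)]

# Count ordered off-diagonal pairs whose embedded points are closer than `tol`.
function count_coincident(X, period; tol=1e-8)
    Z = [periodic_embed(x, period) for x in X]
    n = 0
    for i in eachindex(Z), j in eachindex(Z)
        i != j || continue
        n += sqrt(sum(abs2, Z[i] .- Z[j])) < tol  # Bool adds as 0 or 1
    end
    return n
end

period = 10.0  # hypothetical period that the unit grid wraps onto exactly
@show count_coincident(1.0:1.0:100.0, period)    # many near-zero distances
@show count_coincident(1.0:1.001:100.0, period)  # none
```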
Hmmm, I'm wondering whether this happens because the Matern-1/2 kernel is non-differentiable w.r.t. its inputs at zero. Maybe the finite-differencing algorithm is just smoothing something out?
But then that wouldn't explain the change between versions... hmmmm
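One way to check which answer is right is to compare against an analytic derivative. A minimal Base-Julia sketch, assuming the objective is the sum of a Matern-1/2 kernel matrix k = exp(-d/θ) over periodically embedded inputs with a hypothetical period of 10.0 (the original example's objective and period are not reproduced here). Because the lengthscale only rescales distances, pairs at distance zero stay at zero, so their contribution exp(-d/θ)·d/θ² to the derivative is simply 0, and central differences in θ never cross the kink at d = 0.

```julia
# Hypothetical setup: periodic embedding with an assumed period, then pairwise distances.
periodic_embed(x, period) = [sinpi(2 / period * x), cospi(2 / period * x)]

function pairwise_dists(X, period)
    Z = [periodic_embed(x, period) for x in X]
    return [sqrt(sum(abs2, a .- b)) for a in Z, b in Z]
end

# Sum of the Matern-1/2 kernel matrix, exp(-d / θ), as a function of the lengthscale θ.
objective(θ, D) = sum(exp.(-D ./ θ))

# Analytic derivative: d/dθ exp(-d/θ) = exp(-d/θ) * d / θ², which is 0 wherever d == 0.
objective_grad(θ, D) = sum(exp.(-D ./ θ) .* D ./ θ^2)

# Central finite-difference estimate of the same derivative.
fd_grad(θ, D; h=1e-6) = (objective(θ + h, D) - objective(θ - h, D)) / (2h)

D = pairwise_dists(1.0:1.0:100.0, 10.0)
@show objective_grad(0.7, D) fd_grad(0.7, D)
```

If an AD gradient disagrees with both of these, the AD rule rather than the finite differencing is the suspect.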
Right, I've figured it out! It is kind of because the Matern-1/2 kernel is non-differentiable at 0.
In particular, the Euclidean distance is non-differentiable where the distance between two inputs is zero, because of the square root.
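To see what goes wrong at that point, here is a hand-written chain rule for ∂‖x - y‖/∂x through the square root. This is an illustration of a naive reverse pass, not Zygote's actual rule: the outer factor 1/(2√s) is Inf when s = 0, and Inf times the zero inner gradient gives NaN. `euclidean_grad` is a hypothetical name.

```julia
# Chain rule through d = sqrt(s), s = ‖x - y‖²:  ∂d/∂x = 1/(2√s) * ∂s/∂x = 1/(2√s) * 2(x - y).
function euclidean_grad(x, y)
    r = x .- y
    s = sum(abs2, r)
    return (1 / (2 * sqrt(s))) .* (2 .* r)  # 1/(2√s) is Inf when x == y
end

@show euclidean_grad([3.0, 4.0], [0.0, 0.0])  # finite: r / ‖r‖
@show euclidean_grad([1.0, 2.0], [1.0, 2.0])  # Inf * 0 → NaN
```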
In the previous version, map(k.transform, X) would output a Vector{Vector{Float64}}, whereas the current version outputs a RowVecs{Float64}. This change is fine in principle, but it means that different pullbacks for pairwise(::Euclidean, ::AbstractVector) are being hit. In the new version we hit this Zygote rule that I implemented a while ago, while in the old version we hit something else; I'm not entirely sure what.
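A quick way to convince yourself that the forward pass is not the culprit: both layouts hold identical numbers and yield identical distance matrices, so only the reverse-mode rule that gets dispatched differs. This Base-Julia sketch uses a plain matrix to stand in for what RowVecs wraps and a hypothetical period of 10.0; X is collected so both layouts compute each coordinate with the same scalar operations.

```julia
period = 10.0                 # hypothetical period
X = collect(1.0:1.0:5.0)

# Old-style output: one small vector per input (like Vector{Vector{Float64}}).
vecs = [[sinpi(2 / period * x), cospi(2 / period * x)] for x in X]

# New-style output: one matrix whose rows are the embedded inputs (the layout RowVecs wraps).
mat = hcat(sinpi.(2 / period .* X), cospi.(2 / period .* X))

# Pairwise Euclidean distances computed from each layout.
D_vecs = [sqrt(sum(abs2, a .- b)) for a in vecs, b in vecs]
D_mat = [sqrt(sum(abs2, mat[i, :] .- mat[j, :])) for i in axes(mat, 1), j in axes(mat, 1)]

@show D_vecs == D_mat
```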
In any case, it means that we need to fix whatever is going on in Zygote / improve its numerics.
The following implementation of the adjoint provides a temporary fix:
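using Distances: pairwise, Euclidean, SqEuclidean
using Zygote: @adjoint, pullback

# Temporary workaround: clamp squared distances at eps^2 before taking the square root,
# so the pullback never divides by an exactly-zero distance.
@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
    function _pairwise_euclidean(X)
        δ = eps(eltype(X))^2
        return sqrt.(max.(pairwise(SqEuclidean(), X; dims=dims), δ))
    end
    D, back = pullback(_pairwise_euclidean, X)
    # No gradient for the metric argument; pass through the gradient w.r.t. X.
    return D, function(Δ)
        return (nothing, back(Δ)...)
    end
end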
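Per pair of inputs, the clamp in that adjoint behaves like the following Base-Julia sketch (`clamped_dist_and_grad` is a hypothetical name for illustration): when the squared distance is below δ = eps^2, max() returns the constant δ, so the gradient w.r.t. the inputs is exactly zero instead of NaN, which matches the zero derivative the Matern-1/2 kernel should contribute there.

```julia
# Distance with the squared distance clamped from below at δ = eps^2, plus its gradient w.r.t. x.
function clamped_dist_and_grad(x, y)
    δ = eps(eltype(x))^2
    r = x .- y
    s = sum(abs2, r)
    d = sqrt(max(s, δ))
    # When the clamp is active, max(s, δ) == δ does not depend on x, so the gradient is zero.
    g = s > δ ? r ./ d : zero(r)
    return d, g
end

@show clamped_dist_and_grad([3.0, 4.0], [0.0, 0.0])  # (5.0, [0.6, 0.8])
@show clamped_dist_and_grad([1.0, 2.0], [1.0, 2.0])  # (eps(), [0.0, 0.0])
```

The price is a tiny bias: coincident points report a distance of eps() (about 2.2e-16 for Float64) rather than exactly 0.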
I'll have to think a bit harder about a good long-term solution. This issue crops up once every couple of years. Fingers crossed it gets solved properly this time, and we can wait at least a decade before having to touch the implementation again...
Okay, I think I've fixed it properly in https://github.com/FluxML/Zygote.jl/pull/1307
@wil-j-wil if you have 5 mins, could you dev --local Zygote, check out the above branch (you'll probably have to add my fork as a remote) and confirm that it does indeed solve your problem? I've tested it on my end, but I'd like to confirm that it's working for you before I consider the issue resolved.
Yes, this works for me too! Thanks @willtebbutt
Great. If you upgrade to the latest patch release of Zygote (should be available in an hour or so), the problem should be fixed. Thanks again for raising this problem @wil-j-wil and digging into it; it's a good one to have fixed.
I have found a case in which the gradient (using Zygote) of a kernel with respect to its lengthscale seems to be incorrect: a Matern-1/2 kernel with a periodic transform.
Here's an example:
Note that removing the periodic transform or using a different Matern kernel results in the gradients being correct.
Any ideas what's going on here?
Edit: this bug seems to have been introduced in version 0.10.42, i.e., PR #466