FelixBenning opened 1 year ago

Patch coverage has no change and project coverage change: -16.75% :warning:

Comparison is base (`ef6d459`) 94.16% compared to head (`deebf0c`) 77.41%.
I don't think `ForwardDiff` should be an explicit dependency for KernelFunctions. To me this would make more sense as an extension (which might also allow for different implementations, i.e. Enzyme forward).
I have been playing around with the ideas in this PR, and realized that there are some open questions to make this work.
The first issue is that `kernelmatrix` is specialized for `SimpleKernel`s and `MOKernel`s, which would require some additional thought and changes to make work. For `SimpleKernel`s, just adding the `_evaluate` method is insufficient, as `kernelmatrix` uses `pairwise(metric(...), ...)`. In principle one could think about going deeper and start extending all necessary methods for Distances.jl, but then the question is whether it is worth it.

A wrapper might be easier, because then the only things needed are some additional methods. In that case `GP(DiffWrapper(kernel))` would indeed be a different object, since it might use less specialized methods for `kernelmatrix`.
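As a rough illustration of the wrapper idea, here is a minimal self-contained sketch. `DiffWrapper` is the hypothetical name from this discussion, and `Kernel`/`SqExponentialKernel` are stand-ins defined locally, not the real KernelFunctions.jl types:

```julia
# Stand-in kernel hierarchy; KernelFunctions.jl defines the real `Kernel`.
abstract type Kernel end

struct SqExponentialKernel <: Kernel end
(k::SqExponentialKernel)(x, y) = exp(-sum(abs2, x .- y) / 2)

# The wrapper is itself a `Kernel`, so generic GP code accepts it, but
# `kernelmatrix` methods could dispatch on it and fall back to the generic
# (less specialized) pairwise evaluation.
struct DiffWrapper{K<:Kernel} <: Kernel
    kernel::K
end
(k::DiffWrapper)(x, y) = k.kernel(x, y)  # plain points: delegate unchanged

k = DiffWrapper(SqExponentialKernel())
@assert k([0.0, 0.0], [0.0, 0.0]) == 1.0
```

The point of the sketch is only the dispatch structure: derivative-aware methods would hang off `DiffWrapper`, while everything else delegates.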
Additionally, a GP `gd` that expresses the derivative of some GP `g` is not quite the same object. At least for the exact GP posteriors defined in AbstractGPs.jl, each instance stores the Cholesky decomposition `C` as well as `C\y` for the undifferentiated input/kernel, which can then be efficiently re-used each time the `mean` or `var` is computed. To get the variance `var(::GP, ::DiffPt)` we can't use the existing `C`, so we would need to compute the whole `C` for `d^2/dx1dx2 k(x1,x2)`, or store this matrix in addition.
@Crown421 the `kernelmatrix` thing is an issue I have not considered. Taking derivatives breaks isotropy, leaving only stationarity intact. But I am also not sure why this specialization for `kernelmatrix`

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end
```

is more performant than

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return broadcast(κ, x, permutedims(x))
end
```
I mean `pairwise(metric(κ), x) == broadcast(metric(κ), x, permutedims(x))`. So the specialized implementation essentially does

```julia
K = broadcast(x, permutedims(x)) do x, y
    metric(κ)(x, y)
end                        # first pass over K
map(x -> kappa(κ, x), K)   # second pass over K
```

which accesses the elements of `K` twice. On the other hand,
```julia
K = broadcast(x, permutedims(x)) do x, y
    κ(x, y)
end
```

only requires one access. Since memory access is typically the bottleneck, the general definition should be more performant. That is, unless
```julia
(κ::SimpleKernel)(x, y) = kappa(κ, metric(κ)(x, y))
```

is not inlined and causes more function calls. But in that case it probably makes more sense to force inlining of the definition above with `@inline`; this should ensure that the compiler reduces the general implementation to

```julia
K = broadcast(x, permutedims(x)) do x, y
    kappa(κ, metric(κ)(x, y))
end
```

which should be faster than the two-pass version.
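Setting timing aside, the one-pass and two-pass formulations do compute the same matrix. A self-contained sketch with stand-in `kappa`/`sqeuclidean` functions (not the real KernelFunctions.jl/Distances.jl API):

```julia
# Stand-ins for kappa and the metric of an isotropic kernel.
kappa(d2) = exp(-d2 / 2)
sqeuclidean(x, y) = sum(abs2, x .- y)

xs = [randn(2) for _ in 1:5]

# Two passes: build the distance matrix, then map kappa over it.
K_two_pass = map(kappa, sqeuclidean.(xs, permutedims(xs)))

# One pass: evaluate the fused kernel on each pair directly.
K_one_pass = (kappa ∘ sqeuclidean).(xs, permutedims(xs))

@assert K_one_pass ≈ K_two_pass
```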
While `DiffWrapper(kernel)` may be of type `Kernel`, its compositions are not obvious. I mean for sums it is fine, since summation and differentiation commute. But for a function transform you do not have

$$\frac{d}{d x_i} k(f(x), f(y)) = \Bigl(\frac{\partial}{\partial x_i} k\Bigr)(f(x), f(y))$$

So

```julia
DiffWrapper(kernel) ∘ FunctionTransform(f) != DiffWrapper(kernel ∘ FunctionTransform(f))
```

So if you wanted to treat `k = DiffWrapper(SqExponentialKernel())` as "the" mathematical squared exponential kernel

$$k(x, y) = \exp\Bigl(-\frac{(x-y)^2}{2}\Bigr)$$

which is differentiable, then you would expect the behavior of `DiffWrapper(kernel ∘ FunctionTransform(f))`. So you would have to specialize all the function compositions for `DiffWrapper`. And that feels like the main selling point of KernelFunctions.jl to me. Of course you could tell people to only use `DiffWrapper` at the very end. But man is that ugly for zero reason.
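The non-commutation of differentiation and input transforms can be checked numerically. A self-contained sketch with the 1-d squared exponential kernel and a hypothetical transform `f(x) = 2x`, where the two sides differ exactly by the chain-rule factor `f'(x) = 2`:

```julia
k(x, y) = exp(-(x - y)^2 / 2)      # 1-d squared exponential kernel
dk1(x, y) = -(x - y) * k(x, y)     # ∂₁k, written out by hand
f(x) = 2x                          # hypothetical function transform

x0, y0, h = 0.3, 0.1, 1e-6
lhs = (k(f(x0 + h), f(y0)) - k(f(x0 - h), f(y0))) / (2h)  # d/dx k(f(x), f(y))
rhs = dk1(f(x0), f(y0))                                    # (∂₁k)(f(x), f(y))

@assert !isapprox(lhs, rhs; rtol=1e-3)     # the two sides disagree ...
@assert isapprox(lhs, 2 * rhs; rtol=1e-4)  # ... by the chain-rule factor 2
```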
I do not understand this point. This would automatically happen with this implementation too. I mean `DiffPt(x, partial=i)` is just a special point $x + \partial_i$ which is not in $\mathbb{R}^n$. But it still has an evaluation $y$ and a row in the Cholesky matrix which can be cached. Everything should just work as is with AbstractGPs.jl.
This is really weird...

```julia
julia> @btime kernelmatrix(k, Xc);
  23.836 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa(k, x), KernelFunctions.pairwise(KernelFunctions.metric(k), Xc));
  103.780 ms (8000009 allocations: 183.12 MiB)

julia> @btime k.(Xc, permutedims(Xc));
  78.818 ms (4 allocations: 30.52 MiB)

julia> size(Xc)
(2000,)

julia> size(first(Xc))
(2,)

julia> k
Squared Exponential Kernel (metric = Distances.Euclidean(0.0))
```
Why is the performance of the implementation

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end
```

(https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/matrix/kernelmatrix.jl#L149-L151) worse than the function itself?
> worse than the function itself?
Not sure what function you mean here and what you expect to be worse/better.

The main issue is that your benchmarking is flawed: variables etc. have to be interpolated, since otherwise you suffer, sometimes massively, from type instabilities and inference issues introduced by global variables.
So instead you should perform benchmarks such as
```julia
julia> using KernelFunctions, BenchmarkTools

julia> Xc = ColVecs(randn(2, 2000));

julia> k = GaussianKernel();

julia> @btime kernelmatrix($k, $Xc);
  38.594 ms (5 allocations: 61.05 MiB)

julia> @btime kernelmatrix($k, $Xc);
  35.585 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  37.478 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  33.321 ms (5 allocations: 61.05 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.019 ms (2 allocations: 30.52 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.339 ms (2 allocations: 30.52 MiB)
```
> I mean `pairwise(metric(κ), x) = broadcast(metric(κ), x, permutedims(x))`.
No, not generally. `pairwise` for standard distances (such as `Euclidean`) is implemented in highly optimized ways in Distances.jl (e.g., by exploiting and ensuring symmetry of the distance matrix).
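As a rough illustration of one such optimization (a hypothetical simplified `pairwise_sym`, not the actual Distances.jl implementation): for a symmetric distance, only the upper triangle needs to be evaluated, halving the number of distance computations.

```julia
using LinearAlgebra  # for Symmetric, issymmetric

# Fill only the upper triangle and mirror it. Distances.jl does this and more
# (e.g. BLAS-based formulations for Euclidean distances) internally.
function pairwise_sym(d, xs)
    n = length(xs)
    K = Matrix{Float64}(undef, n, n)
    for j in 1:n, i in 1:j          # upper triangle only: ~half the evaluations
        K[i, j] = d(xs[i], xs[j])
    end
    return Symmetric(K, :U)
end

sqeuclidean(x, y) = sum(abs2, x .- y)
xs = [randn(3) for _ in 1:4]
K = pairwise_sym(sqeuclidean, xs)
@assert K ≈ [sqeuclidean(a, b) for a in xs, b in xs]
```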
> Since memory access is typically the bottleneck, the general definition should be more performant.
Therefore this statement also does not hold generally. If you are concerned about memory allocations, you might also want to use `kernelmatrix!` instead of `kernelmatrix`, which minimizes allocations:
```julia
julia> # Continued from above

julia> K = Matrix{Float64}(undef, length(Xc), length(Xc));

julia> @btime kernelmatrix!($K, $k, $Xc);
  25.775 ms (1 allocation: 15.75 KiB)

julia> @btime kernelmatrix!($K, $k, $Xc);
  30.012 ms (1 allocation: 15.75 KiB)
```
Another disadvantage of broadcasting is that generally it means more work for the compiler (the whole broadcasting machinery is very involved and complicated) and hence increases compilation times.
@devmotion ahh 🤦 I only looked into the `distances/pairwise.jl` file for the `pairwise` function. I did not know that Distances.jl defines this as well. This is why I hate Julia's `using` import mechanism and the `include` instead of import of files. You never know where functions are coming from. It is basically like `from module import *` in Python, which everyone dislikes for the same reason.
I guess if pairwise actually uses the symmetry of distances, then I see where the speedup in the isotropic case comes from.
Yes, I try to avoid `using XX` in packages nowadays and rather use `import XX` or `using XX: f, g, h` to make such relations clearer (it is still convenient to use `using XX` in the REPL IMO).
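The distinction can be illustrated with a toy module (`XX`, `f`, `g` are made up here):

```julia
# A throwaway module standing in for a package.
module XX
f() = 1
g() = 2
end

using .XX: f      # explicit name list: it is clear `f` comes from XX
import .XX        # qualified access for everything else: XX.g()

@assert f() == 1
@assert XX.g() == 2
```

With either style, a reader can find the origin of every name without scanning all `using` lines.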
Caching the Cholesky decomposition
> I do not understand this point. This would automatically happen with this implementation too. I mean `DiffPt(x, partial=i)` is just a special point $x + \partial_i$ which is not in $\mathbb{R}^n$. But it still has an evaluation $y$ and a row in the Cholesky matrix which can be cached. Everything should just work as is with AbstractGPs.jl.
My apologies, I had an error in my thinking here; I was convinced that an additional matrix would need to be cached, not sure why.
Issues with a wrapper
> While `DiffWrapper(kernel)` may be of type `Kernel`, its compositions are not obvious. I mean for sums it is fine, since summation and differentiation commute. But for a function transform you do not have $\frac{d}{d x_i} k(f(x), f(y)) = (\frac{\partial}{\partial x_i} k)(f(x), f(y))$. So `DiffWrapper(kernel) ∘ FunctionTransform(f) != DiffWrapper(kernel ∘ FunctionTransform(f))`. So if you wanted to treat `k = DiffWrapper(SqExponentialKernel())` as "the" mathematical squared exponential kernel, which is differentiable, then you would expect the behavior of `DiffWrapper(kernel ∘ FunctionTransform(f))`. So you would have to specialize all the function compositions for `DiffWrapper`. And that feels like the main selling point of KernelFunctions.jl to me. Of course you could tell people to only use `DiffWrapper` at the very end. But man is that ugly for zero reason.
Well, not zero reason. There are multiple reasons for using a wrapper in this PR, and therefore it comes down to opinion. It would be easy to define some fallback functions that throw an error in problematic cases, advising users to use the wrapper at the end.
Given that differentiable kernels would not be a core feature, but rather an extension that is enabled when also loading a compatible autodiff package, any changes in the main part of KernelFunctions should be minimal and not reduce any performance.
Therefore I would personally prefer starting with a wrapper, at least for now, to have the key functionality available and see additional issues during use. For example, I have already wondered:
- How `DiffPt` should be treated in combination with "normal" points. You mention above mixing the two, but what should for example `vcat(ColVecs(X), DiffPt(x, partial=1))` look like? We get performance benefits from storing points as columns/rows of a matrix of concrete types (i.e. `Matrix{Float64}`). Where do we put the `partial` "annotation" of a `DiffPt`? One option could be to define new types and a load of convenience functions to make it seamless to combine them with existing ones.
- How does `DiffPt` combine with `MOInput`s?
- For `MOInput`s there is a `prepare_isotopic_multi_output_data` method; should there be something similar for `DiffPt`s?
For me these are important usability questions, with a much higher "ugliness" potential than where one can put a wrapper. During a normal session I manipulate a lot of inputs and input collections, but only define a GP/kernel once.
@Crown421
> Well, not zero reason. There are multiple reasons for using a wrapper in this PR, and therefore it comes down to opinion. It would be easy to define some fallback functions that throw an error in problematic cases, advising users to use the wrapper at the end.
I am starting to agree, given that I can not come up with a good solution to the kernelmatrix problem at the moment.
> How `DiffPt` should be treated in combination with "normal" points. You mention above mixing the two, but what should for example `vcat(ColVecs(X), DiffPt(x, partial=1))` look like? We get performance benefits from storing points as columns/rows of a matrix of concrete types (i.e. `Matrix{Float64}`). Where do we put the `partial` "annotation" of a `DiffPt`? One option could be to define new types and a load of convenience functions to make it seamless to combine them with existing ones.
That is something I am currently thinking about a lot. I would think that custom composite types would be a good idea. Storing

```julia
(x, 2)
(x, n)
(y, 1)
...
(y, n)
```

could be replaced and emulated by some sort of dictionary

```julia
x => [2, n]
y => 1:n
```

The advantage is that you could specialize on index ranges to take more than one partial derivative (and use backward diff to get the entire gradient).
But you would still need the ability to interleave points

```julia
(x, 1)
(y, 2)
(x, 2)
```

and I am not yet sure how to fix the abstract order of the points.
Basically what should probably happen is something akin to an SQL join:

| ID | PosID | Partial1 | Partial2 |
|---|---|---|---|
| 1 | 1 | NULL | NULL |
| 2 | 1 | 1 | 2 |
| 3 | 2 | 2 | NULL |
| ... | | | |

| ID | Coord1 | Coord2 | Coord3 |
|---|---|---|---|
| 1 | 0.04 | 1.34 | 2.6 |
| 2 | 42.7 | 1.0 | 3.4 |
| 3 | 2.1 | 0.3 | 4.5 |
| ... | | | |
A left join on (Entries, Positions) would then result in the theoretical list

```julia
[
    DiffPt(pos1, ()),
    DiffPt(pos1, (1, 2)),
    DiffPt(pos2, (2,)),
    ...
]
```

But now I don't have the ranges yet...
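A hypothetical sketch of that join in code, with the `positions`/`entries` layout mirroring the two tables above (a `DiffPt` is represented here as a plain `(coordinates, partials)` tuple):

```julia
# Positions table: each column is one point's coordinates (dense Matrix{Float64}).
positions = [0.04 42.7 2.1;
             1.34  1.0 0.3;
             2.6   3.4 4.5]

# Entries table: (PosID, partials); () means no derivative (NULL, NULL above).
entries = [(1, ()), (1, (1, 2)), (2, (2,))]

# "Left join": look up each entry's coordinates by its position id.
diffpts = [(positions[:, id], partials) for (id, partials) in entries]

@assert diffpts[2] == ([0.04, 1.34, 2.6], (1, 2))
```

This keeps the coordinates in one concrete matrix (for the performance benefits mentioned above) while the partial annotations live in a separate, cheap-to-store list.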
Summary
This is a minimal implementation to enable the simulation of gradients (and higher-order derivatives) of GPs (see also https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/504).
Proposed changes
For a covariance kernel $k$ of a GP $Z$, i.e.

$$k(x, y) = \mathrm{cov}(Z(x), Z(y)),$$

a `DiffPt` allows the differentiation of $Z$, i.e.

$$k(\mathtt{DiffPt}(x, \mathtt{partial}=i), y) = \mathrm{cov}\Bigl(\tfrac{\partial}{\partial x_i} Z(x), Z(y)\Bigr);$$

for higher-order derivatives, `partial` can be any iterable, i.e.

$$k(\mathtt{DiffPt}(x, \mathtt{partial}=(i, j)), y) = \mathrm{cov}\Bigl(\tfrac{\partial^2}{\partial x_i \partial x_j} Z(x), Z(y)\Bigr).$$
The code for this feature is extremely minimal but allows the simulation of arbitrary derivatives of Gaussian processes. It only contains

- the `DiffPt` type, and
- an `_evaluate(::T, x::DiffPt, y::DiffPt) where {T<:Kernel}` function which calls `partial` functions that take the derivatives.

What alternatives have you considered?
This is the implementation with the smallest footprint but not the most performant. What essentially happens here is the simulation of the multivariate GP $f = (Z, \nabla Z)$ which is a $d+1$ dimensional GP if $Z$ is a univariate GP with input dimension $d$. Due to the "no multi-variate kernels" design philosophy of KernelFunctions.jl we are forced to calculate the entries of the covariance matrix one-by-one. It would be more performant to calculate the entire matrix in one go using backward diff for the first pass and forward diff for the second derivative.
It might be possible to somehow specialize on ranges to get back this performance, but it is not completely clear how: we do not call the kernel in a way that could easily be caught by specializing on `broadcast`; in reality the evaluation ranges over all pairs of the input lists, not just a zip.
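For concreteness, a minimal self-contained sketch of the `DiffPt`/`_evaluate` idea described above. Central finite differences stand in for ForwardDiff, and `partial` is restricted to a single index for brevity (the PR supports iterables of indices); the SE kernel and its hand-written derivative are used to check the result:

```julia
# `DiffPt` and `_evaluate` are the PR's names; this is a simplified stand-in.
struct DiffPt{T}
    x::T
    partial::Int   # 0 = no derivative; first order only in this sketch
end

function _evaluate(k, p::DiffPt, q::DiffPt; h=1e-5)
    # Central finite difference of f in coordinate i (i == 0: no derivative).
    d(f, x, i) = begin
        i == 0 && return f(x)
        e = zeros(length(x)); e[i] = h
        (f(x .+ e) - f(x .- e)) / (2h)
    end
    # cov(∂_i Z(x), ∂_j Z(y)) = ∂_{x_i} ∂_{y_j} k(x, y): differentiate in y, then x.
    d(x -> d(y -> k(x, y), q.x, q.partial), p.x, p.partial)
end

k(x, y) = exp(-sum(abs2, x .- y) / 2)
x, y = [0.2, 0.0], [0.0, 0.0]

# For the SE kernel, ∂_{x₁} k(x, y) = -(x₁ - y₁) k(x, y).
@assert isapprox(_evaluate(k, DiffPt(x, 1), DiffPt(y, 0)),
                 -0.2 * k(x, y); atol=1e-6)
```

In the PR, an autodiff backend replaces the finite differences, but the structure is the same: evaluation on `DiffPt`s reduces to (possibly mixed) partial derivatives of the scalar kernel, one covariance entry at a time.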
Breaking changes
None.