JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Introduce DiffPt for the covariance function of derivatives #508

Open FelixBenning opened 1 year ago

FelixBenning commented 1 year ago

Summary

This is a minimal implementation to enable the simulation of gradients (and higher order derivatives) of GPs (see also https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/504)

Proposed changes

For a covariance kernel k of a GP Z, i.e.

k(x,y) # = Cov(Z(x), Z(y)),

a DiffPt allows the differentiation of Z, i.e.

k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))

For higher-order derivatives, partial can be any iterable, e.g.

k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
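For reference, here is the identity that makes this work. It assumes Z is mean-square differentiable up to the required order, so that differentiation and covariance can be interchanged. The multi-indices α and β correspond to the partial tuples of the two DiffPts:

$$ \operatorname{Cov}\bigl(\partial^{\alpha} Z(x),\, \partial^{\beta} Z(y)\bigr) = \partial_x^{\alpha}\, \partial_y^{\beta}\, k(x, y) $$

For example, for the squared exponential kernel $k(x, y) = \exp\bigl(-\|x - y\|^2/2\bigr)$ this gives

$$ \operatorname{Cov}\bigl(\partial_i Z(x), Z(y)\bigr) = -(x_i - y_i)\, k(x, y), \qquad \operatorname{Cov}\bigl(\partial_i Z(x), \partial_j Z(y)\bigr) = \bigl(\delta_{ij} - (x_i - y_i)(x_j - y_j)\bigr)\, k(x, y). $$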

The code for this feature is extremely minimal but allows the simulation of arbitrary derivatives of Gaussian processes. It only contains

What alternatives have you considered?

This is the implementation with the smallest footprint, but not the most performant one. What essentially happens here is the simulation of the multivariate GP $f = (Z, \nabla Z)$, which is a $(d+1)$-dimensional GP if $Z$ is a univariate GP with input dimension $d$. Due to the "no multivariate kernels" design philosophy of KernelFunctions.jl, we are forced to calculate the entries of the covariance matrix one by one. It would be more performant to calculate the entire matrix in one go, using reverse-mode (backward) differentiation for the first derivative and forward-mode differentiation for the second.
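The sketch below shows what that entry-by-entry assembly computes. It assumes the squared exponential kernel with unit lengthscale so the derivatives can be written out by hand, whereas the PR obtains them with automatic differentiation. The names se and joint_cov are hypothetical and are not KernelFunctions.jl API.

```julia
# Squared exponential kernel k(x, y) = exp(-|x - y|^2 / 2) on R^d.
se(x, y) = exp(-sum(abs2, x .- y) / 2)

# Covariance between (Z(x), ∂₁Z(x), …, ∂_dZ(x)) and (Z(y), ∂₁Z(y), …, ∂_dZ(y)),
# filled in one entry at a time. Index 0 means "no derivative", i ≥ 1 means ∂ᵢ.
function joint_cov(x::AbstractVector, y::AbstractVector)
    d = length(x)
    r = x .- y
    k = se(x, y)
    C = Matrix{Float64}(undef, d + 1, d + 1)
    for a in 0:d, b in 0:d
        C[a + 1, b + 1] = if a == 0 && b == 0
            k                               # Cov(Z(x), Z(y))
        elseif a == 0
            r[b] * k                        # ∂k/∂y_b
        elseif b == 0
            -r[a] * k                       # ∂k/∂x_a
        else
            ((a == b) - r[a] * r[b]) * k    # ∂²k/∂x_a∂y_b
        end
    end
    return C
end
```

Each of the (d+1)² entries costs a separate kernel evaluation here. This is the overhead that computing the whole block with one reverse-mode gradient plus forward-mode passes would avoid.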

It might be possible to specialize on ranges to get back this performance, but it is not clear how, since we do not call

k.(1:d, 1:d)

which could easily be caught by specializing on broadcast. In reality we do something like

k.([(x, 1), ..., (x, d)], [(y, 1), ..., (y, d)])

And even this is not quite accurate, as we consider all pairs of these lists and not just a zip.
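To make the zip-versus-all-pairs distinction concrete, here is a small sketch in plain Julia. The function g is a hypothetical stand-in for the kernel that merely reports which partial indices it was called with.

```julia
# Hypothetical stand-in for a kernel on (point, partial-index) pairs:
# it just returns the pair of partial indices it received.
g(p, q) = (p[2], q[2])

x, y, d = [0.0, 1.0], [2.0, 3.0], 2
px = [(x, i) for i in 1:d]   # [(x, 1), ..., (x, d)]
py = [(y, j) for j in 1:d]   # [(y, 1), ..., (y, d)]

zipped   = g.(px, py)                # length-d vector: only pairs (i, i)
allpairs = g.(px, permutedims(py))   # d×d matrix: every pair (i, j)
```

The covariance matrix needs allpairs, so a specialization would have to recognize a broadcast of this shape rather than a plain k.(1:d, 1:d).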

Breaking changes

None.

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage has no change and project coverage change: -16.75 ⚠️

Comparison is base (ef6d459) 94.16% compared to head (deebf0c) 77.41%.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           master     #508       +/-   ##
===========================================
- Coverage   94.16%   77.41%   -16.75%
===========================================
  Files          52       54        +2
  Lines        1387     1430       +43
===========================================
- Hits         1306     1107      -199
- Misses         81      323      +242
```

| Impacted Files | Coverage Δ |
|---|---|
| src/KernelFunctions.jl | `100.00% <ø> (ø)` |
| src/diffKernel.jl | `0.00% <0.00%> (ø)` |
| src/mokernels/differentiable.jl | `0.00% <0.00%> (ø)` |

... and 19 files with indirect coverage changes


Crown421 commented 1 year ago

I don't think ForwardDiff should be an explicit dependency of KernelFunctions. To me this would make more sense as an extension (which might also allow for different implementations, e.g. Enzyme's forward mode).

Crown421 commented 1 year ago

I have been playing around with the ideas in this PR and realized that there are some open questions to resolve to make this work.

The first issue is that kernelmatrix is specialized for SimpleKernels and MOKernels, which would require some additional thought and changes to make work. For SimpleKernels, just adding the _evaluate method is insufficient, as kernelmatrix uses pairwise(metric(...), ...). In principle one could think about going deeper and extending all necessary methods for Distances.jl, but at that point the question is whether it is worth it.

A wrapper might be easier, because then all that is needed is some additional methods. GP(DiffWrapper(Kernel)) would then indeed be a different object, since it might use less specialized methods for kernelmatrix.

Additionally, a GP gd that expresses the derivative of some GP g is not quite the same object. At least for the exact GP posteriors defined in AbstractGPs.jl, each instance stores the Cholesky decomposition C as well as C\y for the undifferentiated input kernel, which can then be efficiently reused each time the mean or var is computed. To get the variance var(::GP, ::DiffPt) we can't use the existing C, so we would need to compute the whole C for $\partial^2 k(x_1, x_2) / \partial x_1 \partial x_2$ or store this matrix in addition.

FelixBenning commented 1 year ago

kernelmatrix

@Crown421 The kernelmatrix issue is something I had not considered. Taking derivatives breaks isotropy, leaving only stationarity intact. But I am also not sure why this specialization of kernelmatrix

function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end

is more performant than

function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return broadcast(κ, x, permutedims(x))
end

I mean pairwise(metric(κ), x) = broadcast(metric(κ), x, permutedims(x)). So the specialized implementation does essentially

K = broadcast(x, permutedims(x)) do xi, yj
    metric(κ)(xi, yj)
end  # first pass over K
map(d -> kappa(κ, d), K)  # second pass over K

which accesses the elements of K twice. On the other hand

K = broadcast(x, permutedims(x)) do xi, yj
    κ(xi, yj)
end

only requires one pass over K. Since memory access is typically the bottleneck, the general definition should be more performant. That is, unless

(κ::SimpleKernel)(x,y) = kappa(κ, metric(κ)(x,y))

is not inlined and causes additional function calls. But in that case it probably makes more sense to force inlining of the method above with @inline, which should ensure that the compiler reduces the general implementation to

K = broadcast(x, permutedims(x)) do xi, yj
    kappa(κ, metric(κ)(xi, yj))
end

This should be faster than the two-pass version.

Issues with a wrapper

While DiffWrapper(kernel) may be of type Kernel, its compositions are not obvious. For sums it is fine, since summation and differentiation commute. But for a function transform you do not have

$$ \frac{\partial}{\partial x_i} k\bigl(f(x), f(y)\bigr) = \Bigl(\frac{\partial}{\partial x_i} k\Bigr)\bigl(f(x), f(y)\bigr) $$
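Instead, assume $f\colon \mathbb{R}^d \to \mathbb{R}^m$ is differentiable and write $\partial_l k$ for the derivative of $k$ in the $l$-th coordinate of its first argument. The chain rule then picks up the Jacobian of $f$:

$$ \frac{\partial}{\partial x_i} k\bigl(f(x), f(y)\bigr) = \sum_{l=1}^{m} (\partial_l k)\bigl(f(x), f(y)\bigr)\, \frac{\partial f_l}{\partial x_i}(x), $$

In general, the two sides only agree when the Jacobian of $f$ is the identity, e.g. a translation $f(x) = x + c$.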

So

DiffWrapper(kernel) ∘ FunctionTransform(f) != DiffWrapper(kernel ∘ FunctionTransform(f))
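Here is a one-dimensional numerical check of this. The sketch assumes the unit-lengthscale squared exponential kernel and the transform f(x) = 2x. The function names are hypothetical and only mimic the two composition orders; they do not use KernelFunctions.jl's FunctionTransform or any actual DiffWrapper type.

```julia
# 1-D squared exponential kernel and its derivative in the first argument.
k(x, y)  = exp(-(x - y)^2 / 2)
dk(x, y) = -(x - y) * k(x, y)        # (∂₁k)(x, y)

f(x) = 2x                            # a simple input transform

# Mimics "DiffWrapper(kernel) ∘ FunctionTransform(f)": differentiate, then transform.
diff_then_transform(x, y) = dk(f(x), f(y))

# Mimics "DiffWrapper(kernel ∘ FunctionTransform(f))": the true derivative of
# x ↦ k(f(x), f(y)), approximated by a central finite difference.
function transform_then_diff(x, y; h = 1e-6)
    return (k(f(x + h), f(y)) - k(f(x - h), f(y))) / (2h)
end
```

The true derivative is twice diff_then_transform, which is exactly the Jacobian factor f'(x) = 2 from the chain rule.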

So if you wanted to treat k=DiffWrapper(SqExponentialKernel()) as "the" mathematical squared exponential kernel

$$ k(x,y) = \exp\Bigl(-\frac{(x-y)^2}{2}\Bigr) $$

which is differentiable, then you would expect the behavior of DiffWrapper(kernel ∘ FunctionTransform(f)). So you would have to specialize every function composition for DiffWrapper, and composability feels like the main selling point of KernelFunctions.jl to me. Of course you could tell people to only apply DiffWrapper at the very end. But man, is that ugly for zero reason.

Caching the cholesky decomposition

I do not understand this point. This would automatically happen with this implementation too. DiffPt(x, partial=i) is just a special point $x + \partial_i$ which is not in $\mathbb{R}^n$. But it still has an evaluation $y$ and a row in the Cholesky factor, which can be cached. Everything should just work as is with AbstractGPs.jl.
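The following sketch supports this view. It assumes a one-dimensional squared exponential kernel with hand-derived derivative covariances, and the helper names are hypothetical. Value and derivative observations go into a single kernel matrix, which is symmetric positive definite and can be Cholesky-factorized and cached like any other.

```julia
using LinearAlgebra

# 1-D squared exponential kernel and the covariances involving Z'.
k(x, y)   = exp(-(x - y)^2 / 2)           # Cov(Z(x),  Z(y))
k10(x, y) = -(x - y) * k(x, y)            # Cov(Z'(x), Z(y))
k01(x, y) = (x - y) * k(x, y)             # Cov(Z(x),  Z'(y))
k11(x, y) = (1 - (x - y)^2) * k(x, y)     # Cov(Z'(x), Z'(y))

# A point is (location, order): order 0 = value, 1 = first derivative.
function mixed_cov(p, q)
    (x, a), (y, b) = p, q
    a == 0 && b == 0 && return k(x, y)
    a == 1 && b == 0 && return k10(x, y)
    a == 0 && b == 1 && return k01(x, y)
    return k11(x, y)
end

pts = [(0.0, 0), (1.5, 0), (0.0, 1)]          # Z(0), Z(1.5), Z'(0)
K = [mixed_cov(p, q) for p in pts, q in pts]  # one matrix over mixed points
C = cholesky(Symmetric(K))                    # factorized like any kernel matrix
```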

FelixBenning commented 1 year ago

This is really weird...

julia> @btime kernelmatrix(k, Xc);
  23.836 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa(k, x), KernelFunctions.pairwise(KernelFunctions.metric(k), Xc));
  103.780 ms (8000009 allocations: 183.12 MiB)

julia> @btime k.(Xc, permutedims(Xc));
  78.818 ms (4 allocations: 30.52 MiB)

julia> size(Xc)
(2000,)

julia> size(first(Xc))
(2,)

julia> k
Squared Exponential Kernel (metric = Distances.Euclidean(0.0))

Why is the performance of the implementation

function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/matrix/kernelmatrix.jl#L149-L151 worse than the function itself?

devmotion commented 1 year ago

worse than the function itself?

Not sure what function you mean here and what you expect to be worse/better.

The main issue is that your benchmarking is flawed, variables etc. have to be interpolated since otherwise you suffer, sometimes massively, from type instabilities and inference issues introduced by global variables.

So instead you should perform benchmarks such as

julia> using KernelFunctions, BenchmarkTools

julia> Xc = ColVecs(randn(2, 2000));

julia> k = GaussianKernel();

julia> @btime kernelmatrix($k, $Xc);
  38.594 ms (5 allocations: 61.05 MiB)

julia> @btime kernelmatrix($k, $Xc);
  35.585 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  37.478 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  33.321 ms (5 allocations: 61.05 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.019 ms (2 allocations: 30.52 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.339 ms (2 allocations: 30.52 MiB)

I mean pairwise(metric(κ), x) = broadcast(metric(κ), x, permutedims(x)).

No, not generally. pairwise for standard distances (such as Euclidean) is implemented in highly optimized ways in Distances (e.g., by exploiting and ensuring symmetry of the distance matrix).
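This sketch illustrates only the symmetry idea in plain Julia. It is not Distances.jl's actual implementation, and symmetric_pairwise! is a hypothetical name.

```julia
# Fill K[i, j] = f(xs[i], xs[j]) for a symmetric f by evaluating only i ≤ j
# and mirroring, i.e. roughly half the evaluations of a full broadcast.
function symmetric_pairwise!(K::AbstractMatrix, f, xs::AbstractVector)
    n = length(xs)
    size(K) == (n, n) || throw(DimensionMismatch("K must be $(n)×$(n)"))
    for j in 1:n, i in 1:j          # column-major friendly loop order
        K[i, j] = f(xs[i], xs[j])
        K[j, i] = K[i, j]
    end
    return K
end

sqeuclid(x, y) = sum(abs2, x .- y)
xs = [[0.0, 0.0], [1.0, 2.0], [-1.5, 0.5], [3.0, -1.0]]
K = symmetric_pairwise!(Matrix{Float64}(undef, 4, 4), sqeuclid, xs)
```

A plain broadcast such as sqeuclid.(xs, permutedims(xs)) evaluates every entry, including both (i, j) and (j, i).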

Since memory access is typically the bottleneck, the general definition should be more performant.

Therefore this statement also does not hold generally. If you are concerned about memory allocations, you might also want to use kernelmatrix! instead of kernelmatrix, which minimizes allocations:

julia> # Continued from above

julia> K = Matrix{Float64}(undef, length(Xc), length(Xc));

julia> @btime kernelmatrix!($K, $k, $Xc);
  25.775 ms (1 allocation: 15.75 KiB)

julia> @btime kernelmatrix!($K, $k, $Xc);
  30.012 ms (1 allocation: 15.75 KiB)

Another disadvantage of broadcasting is that generally it means more work for the compiler (the whole broadcasting machinery is very involved and complicated) and hence increases compilation times.

FelixBenning commented 1 year ago

@devmotion ahh 🤦 I only looked into the distances/pairwise.jl file for the pairwise function. I did not know that Distances.jl defines this as well. This is why I hate Julia's using mechanism and the use of include instead of importing files: you never know where functions are coming from. It is basically like from module import * in Python, which everyone dislikes for the same reason.

I guess if pairwise actually uses the symmetry of distances, then I see where the speedup in the isotropic case comes from.

devmotion commented 1 year ago

Yes, I try to avoid using XX in packages nowadays and rather use import XX or using XX: f, g, h to make such relations clearer (still convenient to use using XX in the REPL IMO).
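For illustration, here are the two explicit styles devmotion mentions, shown with standard-library modules rather than KernelFunctions.jl. Calling parentmodule (or using the @which macro in the REPL) is a quick way to check where a function is actually defined.

```julia
# Bring only the names you need into scope...
using LinearAlgebra: cholesky, Symmetric
# ...or import the module and qualify every call.
import Statistics

A = Symmetric([4.0 2.0; 2.0 3.0])
F = cholesky(A)                  # origin visible from the using line above
m = Statistics.mean([1, 2, 3])   # qualified call, origin is explicit

# When in doubt, ask Julia where a function is defined.
origin = parentmodule(cholesky)  # the LinearAlgebra module
```

Note that using LinearAlgebra: cholesky binds only cholesky and Symmetric, not the name LinearAlgebra itself.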

Crown421 commented 1 year ago

Caching the cholesky decomposition

I do not understand this point. This would automatically happen with this implementation too. I mean DiffPt(x, partial=i) is just a special point $x + \partial_i$ which is not in $\mathbb{R}^n$. But it still has an evaluation $y$ and a row in the Cholesky matrix which can be cached. Everything should just work as is with AbstractGPs.jl

My apologies, I made an error in my thinking here; I was convinced that an additional matrix would need to be cached, not sure why.

Issues with a wrapper

While DiffWrapper(kernel) may be of type kernel, its compositions are not obvious. I mean for sums it is fine, since sum and differentiation commute. But for a function transform you do not have

$$ \frac{\partial}{\partial x_i} k\bigl(f(x), f(y)\bigr) = \Bigl(\frac{\partial}{\partial x_i} k\Bigr)\bigl(f(x), f(y)\bigr) $$

So

DiffWrapper(kernel) ∘ FunctionTransform(f) != DiffWrapper(kernel ∘ FunctionTransform(f))

So if you wanted to treat k=DiffWrapper(SqExponentialKernel()) as "the" mathematical squared exponential kernel

$$ k(x,y) = \exp\Bigl(-\frac{(x-y)^2}{2}\Bigr) $$

which is differentiable, then you would expect the behavior of DiffWrapper(kernel ∘ FunctionTransform(f)). So you would have to specialize all the function composition for DiffWrapper. And that feels like the main selling point of KernelFunctions.jl to me. Of course you could tell people to only use DiffWrapper at the very end. But man is that ugly for zero reason.

Well, not zero reason. There are multiple reasons for using a wrapper in this PR, and therefore it comes down to opinion. It would be easy to define some fallback functions that throw an error in problematic cases, advising users to use the wrapper at the end.
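The following is a minimal sketch of such a fallback, using a hypothetical stand-in type hierarchy. AbstractKernel, SqExp, Transformed, DiffWrapper and compose are illustrative names only, not KernelFunctions.jl API. Dispatch on the wrapper turns the problematic composition order into an explicit error that points users to the supported order.

```julia
# Hypothetical types mimicking the discussion; NOT KernelFunctions.jl types.
abstract type AbstractKernel end
struct SqExp <: AbstractKernel end
struct Transformed{K<:AbstractKernel,F} <: AbstractKernel
    kernel::K
    f::F
end
struct DiffWrapper{K<:AbstractKernel} <: AbstractKernel
    kernel::K
end

# Generic composition of a kernel with an input transform...
compose(k::AbstractKernel, f) = Transformed(k, f)

# ...and a more specific method that rejects transforming a wrapped kernel.
function compose(::DiffWrapper, f)
    throw(ArgumentError(
        "transforming a DiffWrapper is not supported; " *
        "wrap the transformed kernel instead: DiffWrapper(compose(k, f))",
    ))
end

ok = DiffWrapper(compose(SqExp(), x -> 2x))   # wrapper applied last: fine
```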

Given that differentiable kernels would not be a core feature, but rather a package extension loaded together with a compatible autodiff package, any changes to the main part of KernelFunctions should be minimal and should not reduce performance.

Therefore I would personally prefer starting with a wrapper, at least for now, to have the key functionality available and see additional issues during use. For example, I have already wondered:

  1. How DiffPt should be treated in combination with "normal" points. You mention above mixing the two, but what should, for example, vcat(ColVecs(X), DiffPt(x, partial=1)) look like? We get performance benefits from storing points as columns/rows of a matrix of a concrete type (i.e. Matrix{Float64}). Where do we put the partial "annotation" of a DiffPt? One option could be to define new types and a load of convenience functions to make it seamless to combine them with existing ones.

  2. How does DiffPt combine with MOInputs?

  3. For MOInputs there is a prepare_isotopic_multi_output_data method; should there be something similar for DiffPts?

For me these are important usability questions, with a much higher "ugliness" potential than where one can put a wrapper. During a normal session, I manipulate a lot of inputs and input collections, but only define a GP/ kernel once.

FelixBenning commented 1 year ago

@Crown421

Well, not zero reason. There are multiple reasons for using a wrapper in this PR, and therefore it comes down to opinion. It would be easy to define some fallback functions that throw an error in problematic cases, advising users to use the wrapper at the end.

I am starting to agree, given that I cannot come up with a good solution to the kernelmatrix problem at the moment.

How DiffPt should be treated in combination with "normal" points. You mention above mixing the two, but what should, for example, vcat(ColVecs(X), DiffPt(x, partial=1)) look like? We get performance benefits from storing points as columns/rows of a matrix of a concrete type (i.e. Matrix{Float64}). Where do we put the partial "annotation" of a DiffPt? One option could be to define new types and a load of convenience functions to make it seamless to combine them with existing ones.

That is something I am currently thinking about a lot. I would think that custom composite types would be a good idea. Storing

(x, 2)
(x, n)
(y, 1)
...
(y, n)

could be replaced and emulated by some sort of dictionary

x => [2, n]
y => 1:n

The advantage is that you could specialize on index ranges to take more than one partial derivative at once (and use reverse-mode (backward) differentiation to get the entire gradient).
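Here is a rough sketch of that grouping, assuming points are stored as plain vectors. The helpers group_partials and as_range_if_possible are hypothetical. Note that a Dict forgets the order of the original list, which is exactly the interleaving problem raised next.

```julia
# Flat list of (point, partial) entries with n = 3.
entries = [([0.0, 1.0], 2), ([0.0, 1.0], 3),
           ([2.0, 5.0], 1), ([2.0, 5.0], 2), ([2.0, 5.0], 3)]

# Group partial indices by point: point => partials.
# (Vectors as Dict keys are only safe if they are never mutated afterwards.)
function group_partials(entries)
    groups = Dict{eltype(first.(entries)),Vector{Int}}()
    for (p, i) in entries
        push!(get!(groups, p, Int[]), i)
    end
    return groups
end

# Contiguous partials collapse to a range, e.g. [1, 2, 3] -> 1:3, which a
# specialized method could serve with a single gradient computation.
as_range_if_possible(v) = v == first(v):last(v) ? (first(v):last(v)) : v

groups = group_partials(entries)
```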

But you would still need the ability to interleave points

(x,1)
(y,2)
(x,2)

and I am not yet sure how to fix the abstract order of the points.

Basically what should probably happen is something akin to an SQL join:

TABLE: Entries

ID  PosID  Partial1  Partial2
1   1      NULL      NULL
2   1      1         2
3   2      2         NULL
...

TABLE: Positions

ID  Coord1  Coord2  Coord3
1   0.04    1.34    2.6
2   42.7    1.0     3.4
3   2.1     0.3     4.5
...

A left join of Entries with Positions would then result in the theoretical list

[
    DiffPt(pos1, ()),
    DiffPt(pos1, (1,2)),
    DiffPt(pos2, (2,)),
    ...
]
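The following is a minimal sketch of such a join in plain Julia, using the example tables above with positions stored column-wise (as in ColVecs). The DiffPt struct here is a hypothetical stand-in, not the PR's definition. Using view avoids copying coordinates.

```julia
# "Positions" table, one column per position (column-major, like ColVecs).
positions = [0.04 42.7 2.1; 1.34 1.0 0.3; 2.6 3.4 4.5]

# "Entries" table: (position ID, tuple of partials); () means no derivative.
entries = [(1, ()), (1, (1, 2)), (2, (2,))]

# Hypothetical stand-in for the PR's DiffPt.
struct DiffPt{T,P}
    x::T
    partial::P
end

# Left join: resolve each entry's position ID to its coordinates.
join_entries(positions, entries) =
    [DiffPt(view(positions, :, id), partial) for (id, partial) in entries]

pts = join_entries(positions, entries)
```

The entries keep their original order, so interleaving is preserved, but nothing here recovers index ranges yet.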

But now I don't have the ranges yet...