JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

ChainTransform AD performance #466

Closed willtebbutt closed 2 years ago

willtebbutt commented 2 years ago

Summary

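ChainTransform has performance issues on master when differentiated with Zygote. In the benchmark below, the pullback of kernelmatrix for a kernel that composes a lengthscale with a PeriodicTransform is slow and allocates heavily.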

Evidence:

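using BenchmarkTools, KernelFunctions, Zygote

# Inputs go through PeriodicTransform(θ) and then the ScaleTransform added by
# with_lengthscale, so the kernel's transform is a chain of two transforms.
kernel(θ) = with_lengthscale(Matern12Kernel(), 0.5) ∘ PeriodicTransform(θ)

# Applies a single transform via the internal _map (not used in the benchmark).
foo(x) = KernelFunctions._map(PeriodicTransform(1 / 5), x)
bar(θ, x) = kernelmatrix(kernel(θ), x)

const x = randn(500);

# Forwards pass; pb is the pullback with respect to (θ, x).
out, pb = Zygote.pullback(bar, 5.0, x);

# Benchmark only the reverse pass, seeded with a cotangent shaped like out.
Δ = copy(out);
@benchmark $pb($Δ)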

master:

BenchmarkTools.Trial: 36 samples with 1 evaluation.
 Range (min … max):  118.074 ms … 234.881 ms  ┊ GC (min … max): 18.20% … 24.77%
 Time  (median):     140.383 ms               ┊ GC (median):    18.11%
 Time  (mean ± σ):   140.637 ms ±  21.592 ms  ┊ GC (mean ± σ):  20.45% ±  4.18%

  █ █▃█ ▃     ██  ▃ ▃
  █▇███▇█▇▇▁▇▁██▇▇█▇█▇▁▇▇▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  118 ms           Histogram: frequency by time          235 ms <

 Memory estimate: 128.37 MiB, allocs estimate: 1586367.

This branch:

BenchmarkTools.Trial: 1946 samples with 1 evaluation.
 Range (min … max):  1.979 ms … 10.005 ms  ┊ GC (min … max): 0.00% … 77.81%
 Time  (median):     2.236 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.562 ms ±  1.175 ms  ┊ GC (mean ± σ):  8.04% ± 12.87%

  ██▇▆▆▅▃▃▁                                                  ▁
  ███████████▇▄▆▄▅▄▁▁▁▇▇▄▄▄▄▁▄▄▁▁▁▁▄▄▄▁▄▄▅▄▆▄▄▁▄▅▄▆▇▄▅▅▅▅▄▄▄ █
  1.98 ms      Histogram: log(frequency) by time     8.98 ms <

 Memory estimate: 7.70 MiB, allocs estimate: 272.
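As a rough summary, the master figures can be divided by the branch figures. This sketch only does that arithmetic on numbers copied from the two reports above; each report comes from a single benchmark run, so the ratios are indicative rather than precise.

```julia
# Figures copied from the two BenchmarkTools reports above.
master = (median_ms = 140.383, mem_mib = 128.37, allocs = 1_586_367)
branch = (median_ms = 2.236, mem_mib = 7.70, allocs = 272)

speedup     = master.median_ms / branch.median_ms  # ratio of median pullback times
mem_ratio   = master.mem_mib / branch.mem_mib      # ratio of memory estimates
alloc_ratio = master.allocs / branch.allocs        # ratio of allocation counts

println("median speedup:   ", round(speedup; digits = 1))
println("memory reduction: ", round(mem_ratio; digits = 1))
println("alloc reduction:  ", round(Int, alloc_ratio))
```

So the median pullback is roughly 63 times faster, uses roughly 17 times less memory, and makes several thousand times fewer allocations.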

Proposed changes

  1. Use a tuple rather than a vector to hold the transforms being chained together. This makes composition type-stable.
  2. Call _map rather than map, since _map is the API each transform implements.
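The effect of the first change can be illustrated without KernelFunctions. In this hypothetical sketch (Scale, Shift, apply, VecChain, TupleChain and apply_all are invented names, not the package's types), the same two-step chain is stored once in a Vector{Any} and once in a Tuple, and Base.return_types reports what the compiler can infer when each chain is applied.

```julia
# Hypothetical stand-ins for two transforms; these are not KernelFunctions types.
struct Scale
    s::Float64
end
struct Shift
    c::Float64
end
apply(t::Scale, x) = t.s * x
apply(t::Shift, x) = x + t.c

# Vector-backed chain: the element type is abstract, so the compiler cannot
# know which apply method each step will hit.
struct VecChain
    ts::Vector{Any}
end

function apply(c::VecChain, x)
    for t in c.ts          # each call is dispatched at run time
        x = apply(t, x)
    end
    return x
end

# Tuple-backed chain: the concrete type of every transform is part of the
# chain's own type, so the whole composition can be inferred.
struct TupleChain{T<:Tuple}
    ts::T
end

# Peel transforms off the front of the tuple one at a time (stored order).
apply_all(x, ::Tuple{}) = x
apply_all(x, ts::Tuple) = apply_all(apply(first(ts), x), Base.tail(ts))
apply(c::TupleChain, x) = apply_all(x, c.ts)

vc = VecChain(Any[Scale(2.0), Shift(0.5)])
tc = TupleChain((Scale(2.0), Shift(0.5)))

@show apply(vc, 1.0) apply(tc, 1.0)
@show Base.return_types(apply, (typeof(vc), Float64))
@show Base.return_types(apply, (typeof(tc), Float64))
```

Both chains compute the same value, but only the tuple-backed one has a concrete inferred return type. The vector version has to dispatch on each element at run time, which is one plausible source of the extra allocations an AD tool like Zygote runs into when differentiating through it.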

I'm checking that this change is successful by testing that the number of allocations needed to compute the kernelmatrix, in both its forwards pass and its pullback (using Zygote), does not depend on the length of the input vector. I plan to apply this check more widely in the coming days.
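The allocation-invariance check can be sketched with Base Julia alone. This is a minimal, hypothetical illustration, not the PR's test code: allocates_per_element, allocates_fixed, alloc_count and size_invariant are names invented here, the input sizes are arbitrary, and Base.@allocations requires Julia 1.9 or later. In the PR's setting, f would be the kernelmatrix forwards pass or its Zygote pullback.

```julia
# Needs Julia 1.9 or later (Base.@allocations; allequal is from 1.8).

# Hypothetical toy functions, unrelated to KernelFunctions.
# Builds one small Vector per input entry, so allocations grow with length(x).
allocates_per_element(x) = sum(first, [[v] for v in x])
# Allocates a single temporary array (2 .* x) regardless of length(x).
allocates_fixed(x) = sum(abs2, 2 .* x)

# Number of allocations made by f(x), measured after a warm-up call so that
# compilation is not counted.
function alloc_count(f, x)
    f(x)
    return @allocations f(x)
end

# The property under test: the allocation count is the same for every input size.
# Both sizes are fairly large, because Julia may lay out small and large arrays
# differently, which could change the count for reasons unrelated to the code.
size_invariant(f; ns = (10_000, 100_000)) = allequal(alloc_count(f, randn(n)) for n in ns)

@show size_invariant(allocates_fixed)
@show size_invariant(allocates_per_element)
```

On a typical Julia build, the fixed-allocation function should pass this check and the per-element one should fail it, which is the distinction the test is meant to catch.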

What alternatives have you considered?

None

Breaking changes

This change only widens the set of container types that ChainTransform accepts, and changes which one is used by default. On that basis, I'm inclined to say it shouldn't be considered breaking, but I might have missed something obvious.
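The compatibility argument can be illustrated with a parametric container. The sketch below is hypothetical: AbstractTransform, Stretch, Offset and Chain are invented here and only mimic the idea, not KernelFunctions' actual definitions. Because the storage type is a type parameter constrained to either a vector or a tuple, code that builds a chain from a vector keeps working, while composition with ∘ now produces the tuple-backed form.

```julia
# Hypothetical stand-ins; not KernelFunctions' Transform or ChainTransform types.
abstract type AbstractTransform end

struct Stretch <: AbstractTransform
    s::Float64
end
struct Offset <: AbstractTransform
    c::Float64
end
(t::Stretch)(x) = t.s * x
(t::Offset)(x) = x + t.c

# The container type is a parameter, so a Vector (the old storage) and a Tuple
# (the new storage) are both accepted.
struct Chain{V<:Union{AbstractVector{<:AbstractTransform},Tuple{Vararg{AbstractTransform}}}} <: AbstractTransform
    transforms::V
end

# Apply the stored transforms in order, whichever container holds them.
(c::Chain)(x) = foldl((acc, t) -> t(acc), c.transforms; init = x)

# In this sketch, composition picks the tuple-backed form by default;
# t2 ∘ t1 applies t1 first, as with ordinary function composition.
Base.:∘(t2::AbstractTransform, t1::AbstractTransform) = Chain((t1, t2))

legacy = Chain(AbstractTransform[Stretch(2.0), Offset(0.5)])  # vector-backed, still valid
composed = Offset(0.5) ∘ Stretch(2.0)                         # tuple-backed by default

@show typeof(legacy.transforms) typeof(composed.transforms) legacy(1.0) composed(1.0)
```

Both chains give the same result; only the default storage chosen by ∘ differs, which is why existing vector-based construction is unaffected.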

codecov[bot] commented 2 years ago

Codecov Report

Merging #466 (aef58b9) into master (b5af459) will decrease coverage by 0.07%. The diff coverage is 66.66%.

@@            Coverage Diff             @@
##           master     #466      +/-   ##
==========================================
- Coverage   93.16%   93.09%   -0.08%     
==========================================
  Files          52       52              
  Lines        1259     1275      +16     
==========================================
+ Hits         1173     1187      +14     
- Misses         86       88       +2     
Impacted Files                        Coverage Δ
src/transform/chaintransform.jl       80.00% <66.66%> (+1.73%) ↑
src/matrix/kernelpdmat.jl             75.00% <0.00%>  (-6.82%) ↓
src/kernels/normalizedkernel.jl       80.00% <0.00%>  (-2.36%) ↓
src/mokernels/lmm.jl                  100.00% <0.00%> (ø)
src/kernels/kernelsum.jl              100.00% <0.00%> (ø)
src/kernels/kernelproduct.jl          100.00% <0.00%> (ø)
src/kernels/kerneltensorproduct.jl    98.85% <0.00%>  (+0.08%) ↑
src/approximations/nystrom.jl         92.68% <0.00%>  (+0.18%) ↑


willtebbutt commented 2 years ago

@theogf let me know whether my explanation of the testing is sufficient, and I'll add a docstring + merge

willtebbutt commented 2 years ago

Will squash + merge when CI passes