Open willtebbutt opened 1 year ago
Wasn't SnoopPrecompile designed for this precompilation scenario?
In SciML we discussed adding Preferences-based precompilation statements to allow users to switch on/off precompilation in a somewhat granular way, depending on their use cases.
Excellent idea! I was just wondering about this. I've not looked too hard at SnoopPrecompile, so will have to do so.
I was wondering in particular about preferences-based precompilation for anything Zygote-related, but I've not been able to get Zygote to precompile properly yet. For example, I've tried adding this code block to AbstractGPs:
```julia
# `kernels` and `xs` are collections of example kernels / example inputs.
for k in kernels, x in xs
    precompile(kernelmatrix, (typeof(k), typeof(x)))
    precompile(kernelmatrix, (typeof(k), typeof(x), typeof(x)))
    @show typeof(k), typeof(x), typeof(kernelmatrix)
    # Try to force compilation of the Zygote pullback by actually running it:
    (() -> Zygote._pullback(Zygote.Context(), KernelFunctions.kernelmatrix, k, x))()
    # precompile(Zygote._pullback, (typeof(Zygote.Context()), typeof(kernelmatrix), typeof(k), typeof(x)))
    # out, pb = Zygote._pullback(Zygote.Context(), kernelmatrix, k, x)
end
```
but when I run
using Pkg
pkg"activate ."
@time @eval using AbstractGPs, KernelFunctions, Random, LinearAlgebra, Zygote
@time @eval begin
    X = randn(5, 25)
    x = ColVecs(X)
    f = GP(SEKernel())
    fx = f(x, 0.1)
    y = rand(fx)
    logpdf(fx, y)
    out, pb = Zygote._pullback(Zygote.Context(), KernelFunctions.kernelmatrix, SEKernel(), x)
    # out, pb = Zygote._pullback(Zygote.Context(), logpdf, fx, y)
    # Zygote.gradient(logpdf, fx, y)
end;
it still takes around 15s. I'm not entirely clear why this should be the case (e.g. I don't see how the method could be invalidated), but maybe I'm missing something obvious.
You may be aware of this by now, but I just wanted to note that in order to precompile the calls in other libraries you need to use SnoopPrecompile. I have been able to do so with massive benefits in TTFX when Zygote is involved.
I was very much not aware of this (although maybe that's what David is alluding to above). Have you managed to e.g. get the entirety of the forwards- and reverse-passes to compile nicely?
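For anyone finding this thread later, here is a minimal sketch of how SnoopPrecompile can be used to cover calls that reach into other packages like Zygote. The kernel list and inputs below are illustrative assumptions, not the actual workload anyone in this thread used:

```julia
# Hypothetical SnoopPrecompile workload; kernels/inputs are examples only.
using SnoopPrecompile
using KernelFunctions, Zygote

@precompile_setup begin
    # Setup code runs at precompile time but is not itself cached.
    kernels = (SEKernel(), Matern32Kernel())
    x = ColVecs(randn(5, 25))
    @precompile_all_calls begin
        for k in kernels
            kernelmatrix(k, x)
            # Calling into Zygote inside the block lets the caching pick up
            # the pullback code generated in other modules as well.
            Zygote.pullback(kernelmatrix, k, x)
        end
    end
end
```

The key difference from bare `precompile` statements is that the calls are actually executed, so everything they compile, including code owned by other packages, can be recorded.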
With all of the exciting stuff happening with code caching in 1.9, I thought I'd take a look at our latency for some common tasks.
Consider the following code:
It times package load times, and 1st / 2nd evaluation times of some pretty standard AbstractGPs code. On 1.9, I see the following results:
Overall, this doesn't seem too bad.
However, we're not taking advantage of pre-compilation anywhere within the JuliaGPs ecosystem, so I wanted to know what would happen if we tried that. To this end, I added the following `precompile` statements in AbstractGPs:

I've tried to add only pre-compile statements for low-level code that doesn't get involved in combinations of things. For example, I don't think it makes sense to add a pre-compile statement for `kernelmatrix` for a sum of kernels, because you'd have to compile a separate method instance for each collection of pairs of kernel types that you ever encountered, and I want to avoid a combinatorial explosion.

`_logpdf`, `_rand_`, and `_posterior_computations` are bits of code I've pulled out of `logpdf`, `rand`, and `posterior` which are `GP`-independent, i.e. they just depend on matrix types etc. This feels fair, because they don't need to be re-compiled for every new kernel that's used, just when the output of `kernelmatrix` isn't a `Matrix{Float64}` or whatever.

Anyway, the results are:
So it looks like by pre-compiling, we can get a really substantial 4x reduction in time-to-first-inference, or whatever we're calling it.
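For concreteness, here is a sketch of what the GP-independent statements could look like. The helper names follow the discussion above, but the argument types are guesses for illustration, not the actual signatures in the diff:

```julia
# Hypothetical: precompile the matrix-level plumbing for the common
# Matrix{Float64} case, independently of which kernel produced the matrix.
# The signatures below are illustrative guesses, not AbstractGPs' real methods.
using AbstractGPs, LinearAlgebra

let C = Cholesky{Float64,Matrix{Float64}}, v = Vector{Float64}, M = Matrix{Float64}
    precompile(AbstractGPs._logpdf, (C, v, v))
    precompile(AbstractGPs._rand_, (C, v, Int))
    precompile(AbstractGPs._posterior_computations, (C, M, v))
end
```

The point of factoring things this way is that these statements only ever need one specialization per output matrix type, not one per kernel.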
If you use the slightly more complicated kernel
you see (without pre-compilation):
With pre-compilation you see something like:
So here we see a similar performance boost because we've pre-compiled all of the code to compute the kernel matrices for the `SEKernel` and the `Matern32Kernel`, so only the code for `kernelmatrix` of their sum needs to be compiled on the fly.

It does look like there's a small penalty paid in load time, but I think it will typically be outweighed substantially by the compilation savings.
I wonder whether there's a case for adding a for-loop to `KernelFunctions` that pre-compiles the `kernelmatrix`, `kernelmatrix_diag`, etc. methods for each "simple" kernel, where by "simple" I basically just mean anything that's not a composite kernel, and adding the kinds of methods I've discussed above to `AbstractGPs`. It might make the user experience substantially more pleasant 🤷 . I for one would love to have these savings.