Minibatching & ADAM (possibly with Flux)
This is definitely something that we need, although possibly we don't want it to be Flux specific? Does Flux have particular minibatching helpers or something?
GPU support?
This is a good idea, but I would probably put it below getting a basic implementation + abstractions sorted. Possibly the main issue here will be ensuring that we have at least one kernel that plays nicely with this, so you might need to make a small PR to KernelFunctions.jl or something to sort that out.
Add support for non-conjugate likelihoods (this is done in GPflow and [1] by quadrature)
I suspect we're going to need to support both quadrature and Monte Carlo approaches here. As @theogf mentioned at the last meeting, although you often have low-dimensional integrals in the reconstruction term, it's not uncommon to have to work with quite high dimensional integrals (e.g. multi-class classification). In those cases, you hit the curse of dimensionality and cubature will tend not to be a particularly fantastic option. That being said, if you can get away with quadrature in a particular problem, it's typically a very good idea.
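As a rough illustration of the Monte Carlo route for such higher-dimensional reconstruction terms, here is a minimal sketch (the function name, shapes, and mean-field assumption are purely illustrative, not an existing API) for a softmax multi-class likelihood:

```julia
using Random, Statistics

# Plain Monte Carlo estimate of E_{q(f)}[log p(y | f)] for a softmax likelihood,
# where f ∈ R^C are the latent function values at one input and q(f) = N(m, Diagonal(s.^2)).
function mc_expected_loglik(y::Int, m::AbstractVector, s::AbstractVector; n_samples=100)
    C = length(m)
    samples = map(1:n_samples) do _
        f = m .+ s .* randn(C)       # draw f ~ q(f)
        f[y] - log(sum(exp.(f)))     # log softmax probability of class y (use log-sum-exp in practice)
    end
    return mean(samples)
end

mc_expected_loglik(2, zeros(3), ones(3); n_samples=1_000)
```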
This is definitely something that we need, although possibly we don't want it to be Flux specific? Does Flux have particular minibatching helpers or something?
Regarding this, Flux does have minibatch helpers via the DataLoader structure, and its optimisers are quite practical. That said, this is a VERY heavy dependency, and minibatch helpers can probably be found somewhere else. For optimisers there is ongoing work to take them out of Flux, but this is taking forever: https://github.com/FluxML/Optimisers.jl
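For reference, a minimal sketch of what that might look like (assuming Flux's `DataLoader` and `ADAM`; the data and loop body are purely illustrative):

```julia
using Flux

# Toy data: 100 two-dimensional inputs stored column-wise, plus targets.
X = randn(2, 100)
y = randn(100)

# Shuffled minibatches of 10 (input, target) pairs per pass over the data.
loader = Flux.Data.DataLoader((X, y), batchsize=10, shuffle=true)

opt = ADAM(0.01)
for (x_batch, y_batch) in loader
    # evaluate the minibatch ELBO gradient and update the parameters with `opt` here
end
```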
Possibly the main issue here will be ensuring that we have at least one kernel that plays nicely with this, so you might need to make a small PR to KernelFunctions.jl or something to sort that out.
I tried some things already, and on the KernelFunctions.jl side the only issue is kernels without the right constructors; see https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/299
I suspect we're going to need to support both quadrature and Monte Carlo approaches here.
That's true, but it's probably wiser to just start with quadrature for now, and let the API be general enough such that adding MC integration would not be a burden
That's true, but it's probably wiser to just start with quadrature for now, and let the API be general enough such that adding MC integration would not be a burden
Why do you think that this is the case? (Not arguing against it, just curious to understand your reasoning -- I would have imagined that Monte Carlo would be more straightforward to implement)
I think it's more performance related: quadrature works much better than sampling for those 1-D or 2-D integrals. Also, using packages like FastGaussQuadrature makes it super easy.
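For concreteness, a minimal sketch of that approach with FastGaussQuadrature (the helper name and the Bernoulli-logit example are just for illustration):

```julia
using FastGaussQuadrature

# Expectation of g(f) under q(f) = N(m, s²) via n-point Gauss-Hermite quadrature.
# gausshermite(n) returns nodes/weights for ∫ exp(-x²) h(x) dx, so rescale:
#   E[g(f)] ≈ (1/√π) Σᵢ wᵢ g(m + √2 s xᵢ)
function gauss_hermite_expectation(g, m, s; n=20)
    x, w = gausshermite(n)
    return sum(w .* g.(m .+ sqrt(2) .* s .* x)) / sqrt(π)
end

# e.g. expected log-likelihood of a Bernoulli-logit observation y ∈ {-1, +1}:
expected_loglik(y, m, s) = gauss_hermite_expectation(f -> -log1p(exp(-y * f)), m, s)
expected_loglik(+1, 0.2, 0.5)
```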
I'd be interested to know how you've gone about making an AD-friendly quadrature algorithm. I wound up writing an rrule directly here (which I should update to use the new ChainRules-style stuff now that it can call back into AD) in my ConjugateComputationVI package.
I used Opper and Archambeau 2009, so I just use quadrature on the gradient (and the Hessian), here: https://github.com/theogf/AugmentedGaussianProcesses.jl/blob/c7c9e9cf25a278b0855e769ad943d724513df36d/src/inference/quadratureVI.jl#L181 Alternatively, I think this is differentiable: https://github.com/theogf/AugmentedGaussianProcesses.jl/blob/c7c9e9cf25a278b0855e769ad943d724513df36d/src/inference/quadratureVI.jl#L163
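Roughly, the trick is that the gradients of the expected log-likelihood w.r.t. the variational mean and variance can be written as Gaussian expectations of the first and second derivatives of the log-likelihood, which you then evaluate by quadrature. A sketch (reusing the hypothetical `gauss_hermite_expectation` helper from the earlier snippet rather than the actual AugmentedGaussianProcesses.jl code):

```julia
# Opper & Archambeau (2009) identities:
#   ∂/∂m  E_{N(m, s²)}[log p(y | f)] = E[∂ log p(y | f) / ∂f]
#   ∂/∂s² E_{N(m, s²)}[log p(y | f)] = ½ E[∂² log p(y | f) / ∂f²]
# Illustrative Gaussian likelihood: log p(y | f) = -(y - f)² / 2 + const.
y = 1.0
dlogp(f) = y - f        # first derivative w.r.t. f
d2logp(f) = -1.0        # second derivative w.r.t. f

m, s = 0.3, 0.7
dm = gauss_hermite_expectation(dlogp, m, s)        # gradient w.r.t. the mean
dv = gauss_hermite_expectation(d2logp, m, s) / 2   # gradient w.r.t. the variance
```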
I used Opper and Archambeau 2009
As in the O(2N) variational parameters parametrisation, or the tricks to compute the gradient w.r.t. the parameters by re-writing them as expectations of gradients / hessians?
Well both :)
Cool. Regular gradients or natural?
Haha nice. So am I correct in the understanding that my CVI implementation should be basically equivalent to your natural gradient implementation here? (Since CVI is just natural gradients)
Hmm I am not completely sure, there might be some marginal differences... I derived the natural scheme some time ago and forgot what I did exactly...
Hmmm I'd be interested to know. Probably we should chat about this at some point.
That's true, but it's probably wiser to just start with quadrature for now, and let the API be general enough such that adding MC integration would not be a burden
For this, I imagine I'd want to use GPLikelihoods.jl? (although it doesn't seem to be registered yet unless I'm missing something).
This is definitely something that we need, although possibly we don't want it to be Flux specific? Does Flux have particular minibatching helpers or something?
Sure - I wasn't intending to have Flux as a dependency (beyond Functors.jl perhaps), just make sure it could integrate reasonably easily.
I was thinking of defining something like the Flux layer @devmotion was talking about in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/299, which just exposes the parameters and a function to build the model (i.e. something like what's currently in the example), but I don't know if it's better to use ParameterHandling instead?
Regarding this, Flux does have minibatch helpers via the DataLoader structure, and its optimisers are quite practical.
It looks like https://github.com/JuliaML/MLDataPattern.jl has some minibatch helpers which could work instead
Regarding quadrature algorithms: I'd recommend https://github.com/SciML/Quadrature.jl, it provides a unified interface for many different quadrature packages and is fully differentiable.
Regarding the Flux layer: I think one should be able to just specify a function that creates a kernel and a vector of parameters, i.e., `ParamsKernel(f, params)`. Then this could be used with ParameterHandling (just call `ParamsKernel(reverse(ParameterHandling.flatten(kernel))...)`) or Functors (call `ParamsKernel(reverse(Functors.functor(kernel))...)`). If necessary, one could provide convenience functions that allow users to skip the `reverse`.
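A rough sketch of that idea (the `ParamsKernel` struct below is hypothetical, and a NamedTuple stands in for an actual kernel just to keep the example self-contained):

```julia
using ParameterHandling

# Hypothetical ParamsKernel: a flat parameter vector together with a function
# that rebuilds the kernel from it.
struct ParamsKernel{F,P}
    f::F        # params -> kernel
    params::P   # flat parameter vector
end

build(k::ParamsKernel) = k.f(k.params)

# ParameterHandling.flatten returns (params, unflatten), hence the `reverse`
# when constructing a ParamsKernel from an existing object.
θ = (lengthscale = 1.5, variance = 0.8)   # stand-in for a kernel's parameters
pk = ParamsKernel(reverse(ParameterHandling.flatten(θ))...)
build(pk)  # == θ; in practice one would build the kernel from these values
```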
Regarding quadrature algorithms: I'd recommend https://github.com/SciML/Quadrature.jl, it provides a unified interface for many different quadrature packages and is fully differentiable.
IIRC I tried Quadrature.jl and couldn't get it to work in the GP use-case. Firstly, I don't think that it supports Gauss-Hermite quadrature (which is really what you want to use). Unfortunately I can't remember what my other issue with it was.
TBH Quadrature.jl is a bit of overkill for GPs... All you really need is FastGaussQuadrature.jl
The main disadvantage is that FastGaussQuadrature does not provide any error estimates and is not adaptive. But maybe this does not matter here (much)?
But maybe this does not matter here (much)?
My experience (with simple likelihoods) has been that this is indeed the case. Not sure where this starts to be an issue though.
Firstly, I don't think that it supports Gauss-Hermite quadrature (which is really what you want to use).
They also have QuadGK as a backend: https://github.com/JuliaMath/QuadGK.jl
Isn't that just Gaussian quadrature, rather than Gauss-Hermite?
Yes, it uses adaptive Gauss–Kronrod quadrature.
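For the record, a sketch of computing the same kind of 1-D expectation with QuadGK, which does return an error estimate (the likelihood and names are illustrative):

```julia
using QuadGK

# Expectation of g(f) under N(m, s²) via adaptive Gauss-Kronrod quadrature.
# Unlike fixed-order Gauss-Hermite, quadgk adapts and returns an error estimate.
m, s = 0.3, 0.7
g(f) = -log1p(exp(-f))   # e.g. Bernoulli-logit log-likelihood at y = +1
integrand(f) = g(f) * exp(-(f - m)^2 / (2 * s^2)) / (s * sqrt(2π))
val, err = quadgk(integrand, -Inf, Inf)
```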
Side note/suggestion: if the discussion continues any further, it might be easier to follow by separating it into individual issues (e.g. one for quadrature, one for minibatching) :)
Praise be the mono-issue!
Good point! I've opened #3 and #4 so far.
Everything discussed here is either done or in separate issues (#15), so I think it's safe to close?
The current plan/potential things to do include:
Natural gradients [2]

[1] Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable variational Gaussian process classification." Artificial Intelligence and Statistics. PMLR, 2015.
[2] Salimbeni, Hugh, Stefanos Eleftheriadis, and James Hensman. "Natural gradients in practice: Non-conjugate variational inference in Gaussian process models." International Conference on Artificial Intelligence and Statistics. PMLR, 2018.