JuliaStats / MixedModels.jl

A Julia package for fitting (statistical) mixed-effects models
http://juliastats.org/MixedModels.jl/stable

Gradient of LMM objective through automatic differentiation #312

Open dmbates opened 4 years ago

dmbates commented 4 years ago

Copying here from Slack so that it doesn't disappear.

I added a notebook to https://github.com/RePsychLing/ZiFPsychLing/ showing the forward-mode and sketching the reverse-mode automatic differentiation of the objective for linear mixed models. I believe that forward mode, though more tedious, will be the way to go because it preserves the sparsity of the system used to generate L.

I have been playing around a bit with some examples and the way that I think this can be done is to add a vector of BlockArray objects of the same block structure as the L field in the LinearMixedModel structure. The updateL! function would then need to be changed to update the forward-mode L-dot at the same time as L. From there, getting the gradient is straightforward.

I believe the only change needed in other structures is to carry around a scratch matrix of the same size as lambda in an ReMat structure.
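To make the idea of updating L-dot alongside L concrete, here is a minimal dense sketch. It is not the package's `updateL!` and it ignores the blocked, sparse structure entirely. It uses the standard forward-mode rule for the Cholesky factor: if A = LL' and Ȧ is the derivative of A, then L̇ = L Φ(L⁻¹ Ȧ L⁻ᵀ), where Φ keeps the lower triangle and halves the diagonal. The names `phi`, `chol_forward`, the matrix `B`, and the scalar parameterization A(θ) = θ²B + I are all hypothetical, chosen only to mimic Λ'Z'ZΛ + I with Λ = θI.

```julia
using LinearAlgebra

# Φ(X): keep the lower triangle of X and halve its diagonal (hypothetical helper)
function phi(X::AbstractMatrix)
    P = tril(X)              # copy of the lower triangle, zeros above the diagonal
    for i in axes(P, 1)
        P[i, i] /= 2
    end
    return P
end

# Forward-mode rule: given A = L*L' and a symmetric Adot, return Ldot = L * Φ(L \ Adot / L')
function chol_forward(L::LowerTriangular, Adot::AbstractMatrix)
    S = (L \ Adot) / L'
    return LowerTriangular(L * phi(S))
end

# Toy stand-in for Z'Z (assumption: any symmetric positive definite matrix will do)
B = [4.0 1.0 0.5; 1.0 3.0 0.2; 0.5 0.2 2.0]
A(θ) = Symmetric(θ^2 * B + I)   # mimics Λ'Z'ZΛ + I with Λ = θI

θ = 0.7
L = cholesky(A(θ)).L
Ldot = chol_forward(L, 2θ * B)  # dA/dθ = 2θB

# d/dθ logdet(A) needs only the diagonals of L and Ldot
grad = 2 * sum(diag(Ldot) ./ diag(L))

# Central finite difference as a sanity check on the gradient
h = 1e-6
fd = (logdet(cholesky(A(θ + h))) - logdet(cholesky(A(θ - h)))) / (2h)
println(isapprox(grad, fd; rtol = 1e-6))
```

With several parameters, one such `Ldot` would be carried per element of θ, which is where a vector of arrays with the same block structure as L would come in.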

dmbates commented 4 years ago

A progress report: I started by creating a version that works on dense copies of A, L and just as a proof of concept. The type is called ToyModel in the forwarddiff branch. In a few tests it converges in fewer iterations than the version without the gradient, but often the change in time per iteration outweighs the lower number of iterations. However, the emphasis at this point is on getting the gradient correct, not necessarily fast. When the gradient evaluation is incorporated into the standard LinearMixedModel representation, the ToyModel type will be dropped.

palday commented 4 years ago

Interesting. When we get to real benchmarking, I would be very curious how the GLMM does with the gradient information.

dmbates commented 4 years ago

I don't think this approach will be useful for GLMMs. The objective being minimized for a LMM is a relatively simple function of the elements (just the diagonal elements, actually) of L, so we can push the chain rule derivatives from θ through L to the objective. For a GLMM, getting to the evaluation of the objective is much more complicated: the penalized least squares solution is just an intermediate step, and the penalized sum of squared residuals doesn't really give you information on the objective. It just tells you when you are at the conditional modes of the random effects so that you can evaluate the deviance residuals and hence the objective. And that is before Adaptive Gauss-Hermite Quadrature, etc.
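As a sketch of why only the diagonal of L matters for the LMM, assume the ML criterion and the profiled objective in the form given in the package documentation. The symbols here are conventions introduced for illustration: q is the total number of random-effects columns, r is the last diagonal element of L (so that r² is the penalized residual sum of squares), n is the number of observations, and a dot denotes the partial derivative with respect to one element θ_k.

$$
f(\theta) = 2\sum_{i=1}^{q} \log L_{ii}(\theta) + n\left[1 + \log\frac{2\pi\, r^2(\theta)}{n}\right]
$$

$$
\frac{\partial f}{\partial \theta_k} = 2\sum_{i=1}^{q} \frac{\dot L_{ii}}{L_{ii}} + \frac{2n\,\dot r}{r}
$$

Under these assumptions, once the forward-mode derivative of L is available, each gradient component is a sum of ratios of diagonal elements. No comparably direct expression is available for the GLMM objective described above.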

palday commented 4 years ago

That unfortunately is what I was afraid of, once I actually started to think about it.