JuliaStats / MixedModels.jl

A Julia package for fitting (statistical) mixed-effects models
http://juliastats.org/MixedModels.jl/stable
MIT License

Use init_from_lmm as default in fit! method for GLMM #750

Open · dmbates opened 7 months ago

dmbates commented 7 months ago

I have had much more success using init_from_lmm when fitting GLMMs with PRIMA.bobyqa, but I have only tried it with Bernoulli responses. Shall we make it the default?

dmbates commented 7 months ago

The "goldstein" test set in test/pirls.jl is an example where init_from_lmm provides very different, and superior (in the sense of a non-trivially smaller deviance), parameter estimates.

julia> using MixedModels, PooledArrays

julia> goldstein = (
           group = PooledArray(repeat(string.('A':'J'), outer=10)),
           y = [
               83, 3, 8, 78, 901, 21, 4, 1, 1, 39,
               82, 3, 2, 82, 874, 18, 5, 1, 3, 50,
               87, 7, 3, 67, 914, 18, 0, 1, 1, 38,
               86, 13, 5, 65, 913, 13, 2, 0, 0, 48,
               90, 5, 5, 71, 886, 19, 3, 0, 2, 32,
               96, 1, 1, 87, 860, 21, 3, 0, 1, 54,
               83, 2, 4, 70, 874, 19, 5, 0, 4, 36,
               100, 11, 3, 71, 950, 21, 6, 0, 1, 40,
               89, 5, 5, 73, 859, 29, 3, 0, 2, 38,
               78, 13, 6, 100, 852, 24, 5, 0, 1, 39
               ],
           );

julia> gform = @formula(y ~ 1 + (1|group));

julia> m1 = GeneralizedLinearMixedModel(gform, goldstein, Poisson());

julia> fit!(m1).optsum                   # using default starting values
Initial parameter vector: [4.727210823648169, 1.0]
Initial objective value:  246.12019299931663

Optimizer (from NLopt):   LN_BOBYQA
Lower bounds:             [-Inf, 0.0]
ftol_rel:                 1.0e-12
ftol_abs:                 1.0e-8
xtol_rel:                 0.0
xtol_abs:                 [1.0e-10]
initial_step:             [4.727210823648169, 0.75]
maxfeval:                 -1
maxtime:                  -1.0

Function evaluations:     31
Final parameter vector:   [4.1921964390775655, 1.83824520173986]
Final objective value:    193.55873023848017
Return code:              FTOL_REACHED

julia> m1 = GeneralizedLinearMixedModel(gform, goldstein, Poisson());

julia> fit!(m1; init_from_lmm=Set((:β,:θ))).optsum
Initial parameter vector: [4.727210811908346, 2.22268824817508]
Initial objective value:  194.70703488515588

Optimizer (from NLopt):   LN_BOBYQA
Lower bounds:             [-Inf, 0.0]
ftol_rel:                 1.0e-12
ftol_abs:                 1.0e-8
xtol_rel:                 0.0
xtol_abs:                 [1.0e-10]
initial_step:             [4.727210811908346, 1.6670161861313098]
maxfeval:                 -1
maxtime:                  -1.0

Function evaluations:     31
Final parameter vector:   [4.2195828270624585, 3.847141752689735]
Final objective value:    192.001698988138
Return code:              FTOL_REACHED

This is simulated, and probably unrealistic, data with a very large standard deviation for the random effects. The average response in group E (the fifth group) is about 888, whereas the average response in group H (the eighth group) is 0.3.

Nevertheless this is an example where the LMM-based initial values are non-trivially better: the final objective value is 192.00, versus 193.56 with the default starting values.
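To see how extreme the between-group spread is, the sketch below recomputes the per-group mean counts from the goldstein data using only Base Julia and the Statistics standard library (MixedModels is not needed). The name groupmeans is illustrative and not part of any package.

```julia
using Statistics  # standard library; provides `mean`

# Same data as the goldstein example: groups A-J cycle within each row of 10 counts.
group = repeat(string.('A':'J'), outer=10)
y = [
    83, 3, 8, 78, 901, 21, 4, 1, 1, 39,
    82, 3, 2, 82, 874, 18, 5, 1, 3, 50,
    87, 7, 3, 67, 914, 18, 0, 1, 1, 38,
    86, 13, 5, 65, 913, 13, 2, 0, 0, 48,
    90, 5, 5, 71, 886, 19, 3, 0, 2, 32,
    96, 1, 1, 87, 860, 21, 3, 0, 1, 54,
    83, 2, 4, 70, 874, 19, 5, 0, 4, 36,
    100, 11, 3, 71, 950, 21, 6, 0, 1, 40,
    89, 5, 5, 73, 859, 29, 3, 0, 2, 38,
    78, 13, 6, 100, 852, 24, 5, 0, 1, 39,
]

# Mean response for each group label (hypothetical helper variable).
groupmeans = Dict(g => mean(y[group .== g]) for g in unique(group))

# Print groups from largest to smallest mean response.
for g in sort(collect(keys(groupmeans)); by = g -> groupmeans[g], rev = true)
    println(g, ": ", groupmeans[g])
end
```

The largest mean belongs to group E and the smallest to group H, which is the spread the random-effects standard deviation has to absorb.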

dmbates commented 7 months ago

When I initially looked at this I thought that the starting values came from an unweighted LMM fit to the same formula/data combination as the GLMM, which made me wonder what happened to the scale parameter of the Gaussian distribution for the LMM. On closer examination I found that the LMM uses the weights from the GLM fit, so the starting values for θ are in the right range.

palday commented 7 months ago

IIRC the grouseticks example has traditionally been hard to fit -- does it perform better with init_from_lmm?

dmbates commented 6 months ago

The calls to fit are at the end of each benchmark line, but the bottom line is that without init_from_lmm the fit takes about 3.6 seconds and with it about 1.7 seconds.

julia> @b (dataset(:grouseticks), @formula(ticks ~ 1+year+height+ (1|index) + (1|brood) + (1|location)), Poisson()) fit(MixedModel, _[2], first(_), last(_); contrasts, progress=false) seconds=10
3.605 s (387416 allocs: 134.871 MiB, 0.12% gc time)

julia> @b (dataset(:grouseticks), @formula(ticks ~ 1+year+height+ (1|index) + (1|brood) + (1|location)), Poisson()) fit(MixedModel, _[2], first(_), last(_); contrasts, progress=false) seconds=15
3.605 s (387416 allocs: 134.871 MiB, 0.13% gc time)

julia> @b (dataset(:grouseticks), @formula(ticks ~ 1+year+height+ (1|index) + (1|brood) + (1|location)), Poisson()) fit(MixedModel, _[2], first(_), last(_); contrasts, progress=false, fast=true) seconds=15
649.160 ms (73875 allocs: 24.894 MiB)

julia> @b (dataset(:grouseticks), @formula(ticks ~ 1+year+height+ (1|index) + (1|brood) + (1|location)), Poisson()) fit(MixedModel, _[2], first(_), last(_); contrasts, progress=false, init_from_lmm=(:β, :θ)) seconds=15
1.681 s (179238 allocs: 63.296 MiB)

dmbates commented 6 months ago

Also a nearly 4-fold increase in speed for init_from_lmm on the verbagg example (2.37 s versus 0.62 s).

julia> @b fit(MixedModel, @formula(r2 ~ 1 + anger + gender + btype + situ + (1|subj) + (1|item)), dataset(:verbagg), Bernoulli(); contrasts, progress=false)  seconds=5
2.368 s (436902 allocs: 21.895 MiB)

julia> @b fit(MixedModel, @formula(r2 ~ 1 + anger + gender + btype + situ + (1|subj) + (1|item)), dataset(:verbagg), Bernoulli(); contrasts, init_from_lmm=(:β, :θ), progress=false)  seconds=5
624.775 ms (150194 allocs: 11.311 MiB)

palday commented 6 months ago

Do we have an intuition about when things would be worse? Maybe in a model with random slopes? In that case, the fit time for the LMM might be nontrivial.

palday commented 6 months ago

(FWIW I'm increasingly in favor of shifting the default, but want to make sure we have some intuitions about potential failure modes.)

dmbates commented 6 months ago

Here's a collection of model fits, including one with vector-valued random effects.

julia> runglbmk(gltbl; init_from_lmm=(:β, :θ))
Table with 4 columns and 4 rows:
     bmk                                            dsnm         dist                       frm
   ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 1 │ Sample(time=0.191018, allocs=97529, bytes=56…  contra       Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + livch…
 2 │ Sample(time=0.187686, allocs=79237, bytes=45…  contra       Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + :(((≠…
 3 │ Sample(time=0.627815, allocs=150259, bytes=1…  verbagg      Bernoulli{Float64}(p=0.5)  r2 ~ 1 + anger + gender + btype + situ + :(1…
 4 │ Sample(time=1.6794, allocs=179262, bytes=663…  grouseticks  Poisson{Float64}(λ=1.0)    ticks ~ 1 + year + height + :(1 | index) + :…

julia> runglbmk(gltbl)
Table with 4 columns and 4 rows:
     bmk                                            dsnm         dist                       frm
   ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 1 │ Sample(time=0.122951, allocs=65664, bytes=42…  contra       Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + livch…
 2 │ Sample(time=0.201536, allocs=83628, bytes=47…  contra       Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + :(((≠…
 3 │ Sample(time=2.3971, allocs=436827, bytes=229…  verbagg      Bernoulli{Float64}(p=0.5)  r2 ~ 1 + anger + gender + btype + situ + :(1…
 4 │ Sample(time=3.59434, allocs=387440, bytes=14…  grouseticks  Poisson{Float64}(λ=1.0)    ticks ~ 1 + year + height + :(1 | index) + :…

julia> gltbl
Table with 4 columns and 4 rows:
     dsnm         secs  dist                       frm
   ┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 1 │ contra       2.0   Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + livch + :(1 | urban & dist)
 2 │ contra       2.0   Bernoulli{Float64}(p=0.5)  use ~ 1 + age + :(abs2(age)) + urban + :(((≠)("0"))(livch)) + :((1 + urban) | dist)
 3 │ verbagg      15.0  Bernoulli{Float64}(p=0.5)  r2 ~ 1 + anger + gender + btype + situ + :(1 | subj) + :(1 | item)
 4 │ grouseticks  15.0  Poisson{Float64}(λ=1.0)    ticks ~ 1 + year + height + :(1 | index) + :(1 | brood) + :(1 | location)
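To put the two runglbmk tables side by side, the sketch below computes, for each model, the ratio of the default fit time to the init_from_lmm fit time, using the time= values copied from the Sample column (in seconds). A ratio above 1 means init_from_lmm was faster. The names models, t_default, t_lmm_init and speedup are illustrative; runglbmk itself is not reproduced here.

```julia
# Times (seconds) copied from the `time=` field of the two runglbmk tables above.
models     = ["contra (1 | urban & dist)", "contra ((1 + urban) | dist)", "verbagg", "grouseticks"]
t_default  = [0.122951, 0.201536, 2.3971, 3.59434]   # runglbmk(gltbl)
t_lmm_init = [0.191018, 0.187686, 0.627815, 1.6794]  # runglbmk(gltbl; init_from_lmm=(:β, :θ))

# Ratio > 1 means init_from_lmm was faster for that model.
speedup = t_default ./ t_lmm_init

for (name, s) in zip(models, speedup)
    println(rpad(name, 30), round(s; digits = 2))
end
```

On these numbers init_from_lmm is slower only for the first contra model; the model with vector-valued random effects is slightly faster, and verbagg and grouseticks are substantially faster.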

I have been thinking of trying to fit a model to the correct/incorrect response in the EnglishLexicon data, which would be a stress test.

dmbates commented 6 months ago

The code is in bench/runbenchmarks.jl on the db/RegressionTests branch.