leeper / prediction

Tidy, Type-Safe 'prediction()' Methods

Add variances of predictions? #17

Open leeper opened 7 years ago

leeper commented 7 years ago

Now that prediction() has gained the at argument, it seems logical to report variances (à la Stata's margins and R's margins::margins()).

leeper commented 6 years ago

Relevant example (from email), in Stata:

* Low Birth Weight 
webuse lbw

* The logistic regression
logit low i.smoke i.race age lwt ptl ht ui

* The margins proportions among smoke levels
* Note: there is no dy/dx because I want the marginal mean (proportion)
margins smoke, post

* Now we use the proportions to compute the ratios
nlcom _b[1.smoke]/_b[0.smoke]

* Store the log-ratio
nlcom (logpr: ln(_b[1.smoke]) - ln(_b[0.smoke])), post

* Use lincom to get back to ratio world
lincom logpr, eform

R equivalent:


# assumes `lbw` holds Stata's lbw data as a data frame; race is categorical (i.race in Stata)
m <- glm(low ~ smoke + factor(race) + age + lwt + ptl + ht + ui, data = lbw, family = binomial)

# marginal proportions
prediction(m, at = list(smoke = 0:1))

But we need the variances of these predictions to build an R equivalent of nlcom.
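
A rough sketch of what the nlcom step could look like in base R, using the delta method twice: once to get the covariance of the two average predicted probabilities, and once more for the log ratio. This is not what prediction() does internally. The data are simulated stand-ins (the lbw data are not bundled with R), and margin_at() is a hypothetical helper written only for this illustration.

```r
# Simulated stand-in for the lbw data (an assumption, not the Stata dataset)
set.seed(1)
n <- 500
d <- data.frame(smoke = rbinom(n, 1, 0.4), age = rnorm(n, 25, 5))
d$low <- rbinom(n, 1, plogis(-1 + 0.7 * d$smoke - 0.03 * (d$age - 25)))

m <- glm(low ~ smoke + age, data = d, family = binomial)
V <- vcov(m)

# Hypothetical helper: average predicted probability with smoke fixed at `s`,
# plus its gradient with respect to the coefficients.
margin_at <- function(s) {
  nd  <- transform(d, smoke = s)
  X   <- model.matrix(delete.response(terms(m)), nd)
  eta <- drop(X %*% coef(m))
  list(p    = mean(plogis(eta)),
       grad = colMeans(dlogis(eta) * X))  # dlogis() is the derivative of plogis()
}

m0 <- margin_at(0)
m1 <- margin_at(1)

# Covariance matrix of the two margins (first delta-method step)
J  <- rbind(m0$grad, m1$grad)
Vp <- J %*% V %*% t(J)

# Log ratio log(p1) - log(p0) and its SE (second delta-method step)
g        <- c(-1 / m0$p, 1 / m1$p)
logpr    <- log(m1$p) - log(m0$p)
se_logpr <- sqrt(drop(t(g) %*% Vp %*% g))

# Back to the ratio scale, like `lincom logpr, eform`
ratio <- exp(logpr)
ci    <- exp(logpr + c(-1, 1) * qnorm(0.975) * se_logpr)
```

Building the interval on the log scale and exponentiating keeps the lower bound positive, which is the reason for the nlcom/lincom detour in the Stata code.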

randomgambit commented 6 years ago

@leeper looking forward to this! would be amazing!!

leeper commented 6 years ago

Here's a simple Stata example:

. regress y i.sex age distance
. margins sex
Predictive margins                              Number of obs     =      3,000
Model VCE    : OLS

Expression   : Linear prediction, predict()

             |            Delta-method
             |     Margin   Std. Err.      t    P>|t|     [95% Conf. Interval]
         sex |
       male  |   62.73048   .5354111   117.16   0.000     61.68067    63.78029
     female  |   76.71801   .5346591   143.49   0.000     75.66967    77.76634

and the equivalent R code we'll need to generalize:

# setup data
Y <- cbind(1, margex[, c("sex", "age", "distance", "y")])
Y$sex <- as.numeric(Y$sex)
Y$age <- as.numeric(Y$age)
Y$distance <- as.numeric(Y$distance)
Y$y <- as.numeric(Y$y)

# estimate model
m <- lm(y ~ factor(sex) + age + distance, data = Y)

# convert data to matrix
Y$y <- NULL
Ymat <- t(matrix(colMeans(Y)))

## standard error of predictive marginal mean (nonsense given that sex is a factor)
sqrt(diag(Ymat %*% vcov(m) %*% t(Ymat)))

# predictive margin for men
Ymale <- Ymat
Ymale[,2] <- 0
as.vector(Ymale %*% coef(m))
## [1] 62.73048
sqrt(diag(Ymale %*% vcov(m) %*% t(Ymale)))
## [1] 0.5354111

# predictive margin for women
Yfemale <- Ymat
Yfemale[,2] <- 1
as.vector(Yfemale %*% coef(m))
## [1] 76.71801
sqrt(diag(Yfemale %*% vcov(m) %*% t(Yfemale)))
## [1] 0.5346591

leeper commented 6 years ago

We can actually obtain this more easily than I expected, because it's the same as:

> predict(m, setNames(data.frame(Ymale), c("intercept", "sex", "age", "distance")), se.fit = TRUE)
$se.fit
[1] 0.5354111

$df
[1] 2996

$residual.scale
[1] 20.15342

> predict(m, setNames(data.frame(Yfemale), c("intercept", "sex", "age", "distance")), se.fit = TRUE)
$se.fit
[1] 0.5346591

$df
[1] 2996

$residual.scale
[1] 20.15342


> Ymale_df <- Y
> Ymale_df$sex <- 0
> mean(predict(m, Ymale_df))
[1] 62.73048
> Yfemale_df <- Y
> Yfemale_df$sex <- 1
> mean(predict(m, Yfemale_df))
[1] 76.71801
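
One way the two pieces above might be generalized for lm(), sketched under some assumptions: margin_lm() is a hypothetical helper (not part of the prediction package), and mtcars stands in for margex so the example runs with base R alone. It overwrites the at variables in every row, rebuilds the design matrix the way predict() does, and applies x̄ V x̄' to the averaged row.

```r
# Hypothetical helper: average prediction and its SE with some variables held fixed
margin_lm <- function(model, data, at = list()) {
  # set each `at` variable to its fixed value in every row
  for (v in names(at)) data[[v]] <- at[[v]]
  # rebuild the design matrix with the original factor levels and contrasts
  tt <- delete.response(terms(model))
  mf <- model.frame(tt, data, xlev = model$xlevels)
  X  <- model.matrix(tt, mf, contrasts.arg = model$contrasts)
  # the mean of a linear prediction is the prediction at the mean design row
  xbar <- colMeans(X)
  c(estimate = sum(xbar * coef(model)),
    se       = sqrt(as.numeric(t(xbar) %*% vcov(model) %*% xbar)))
}

m <- lm(mpg ~ cyl * am, data = mtcars)

# both regressors fixed: same as a single predict(..., se.fit = TRUE) row
res <- margin_lm(m, mtcars, at = list(cyl = 4, am = 0))

# only am fixed: cyl is averaged over the observed data
res_am1 <- margin_lm(m, mtcars, at = list(am = 1))
```

This shortcut only works because the prediction is linear in the coefficients; for a GLM on the response scale, the gradient has to be averaged row by row instead.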

leeper commented 5 years ago

Need to match output of margins:

> summary(margins::margins(mod))
 factor     AME     SE       z      p   lower   upper
 group2  8.9454 0.9710  9.2129 0.0000  7.0423 10.8485
 group3 18.5723 1.5032 12.3552 0.0000 15.6261 21.5185
   sex1 18.4317 0.9599 19.2011 0.0000 16.5503 20.3132

So add columns:

leeper commented 5 years ago

This is now working for lm():

> summary(prediction(lm(mpg ~ cyl*am, mtcars), at = list(cyl = 4:8, am = 0:1)))
 at(cyl) at(am) Prediction     SE      z          p lower upper
       4      0      22.97 1.4840 15.478  4.854e-54 20.06 25.88
       5      0      20.99 1.1035 19.026  1.037e-80 18.83 23.16
       6      0      19.02 0.7971 23.862 7.665e-126 17.46 20.58
       7      0      17.04 0.6748 25.258 9.138e-141 15.72 18.37
       8      0      15.07 0.8232 18.304  7.718e-75 13.45 16.68
       4      1      27.93 1.0055 27.772 9.410e-170 25.95 29.90
       5      1      24.64 0.8163 30.190 3.234e-200 23.04 26.24
       6      1      21.36 0.9587 22.284 5.340e-110 19.48 23.24
       7      1      18.08 1.3302 13.594  4.324e-42 15.48 20.69
       8      1      14.80 1.7936  8.253  1.549e-16 11.29 18.32

But support for other model types will require some further work.

leeper commented 5 years ago

Stata Benchmarks:

. webuse margex
(Artificial data for margins)

. logit outcome i.sex i.group

Iteration 0:   log likelihood = -1366.0718  
Iteration 1:   log likelihood = -1207.4432  
Iteration 2:   log likelihood = -1186.5543  
Iteration 3:   log likelihood = -1185.6072  
Iteration 4:   log likelihood = -1185.6063  
Iteration 5:   log likelihood = -1185.6063  

Logistic regression                             Number of obs     =      3,000
                                                LR chi2(3)        =     360.93
                                                Prob > chi2       =     0.0000
Log likelihood = -1185.6063                     Pseudo R2         =     0.1321

     outcome |      Coef.   Std. Err.      z    P>|z|     [95% Conf. Interval]
         sex |
     female  |   .5074091   .1289294     3.94   0.000     .2547122     .760106
       group |
          2  |  -1.196351   .1251234    -9.56   0.000    -1.441589   -.9511138
          3  |  -2.591768   .2759998    -9.39   0.000    -3.132718   -2.050819
       _cons |  -1.199567   .1262273    -9.50   0.000    -1.446968   -.9521664

. margins

Predictive margins                              Number of obs     =      3,000
Model VCE    : OIM

Expression   : Pr(outcome), predict()

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |   .1696667   .0064564    26.28   0.000     .1570124    .1823209

. margins, predict(xb)

Predictive margins                              Number of obs     =      3,000
Model VCE    : OIM

Expression   : Linear prediction (log odds), predict(xb)

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |  -1.981424   .0731871   -27.07   0.000    -2.124868    -1.83798

leeper commented 5 years ago

Reference for all GLMs: http://indiana.edu/~jslsoc/stata/ci_computations/spost_deltaci.pdf

leeper commented 5 years ago

Further Stata benchmarks:

. quietly logit outcome distance

. margins

Predictive margins                              Number of obs     =      3,000
Model VCE    : OIM

Expression   : Pr(outcome), predict()

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |   .1696667   .0068061    24.93   0.000      .156327    .1830063

. margins, predict(xb)

Predictive margins                              Number of obs     =      3,000
Model VCE    : OIM

Expression   : Linear prediction (log odds), predict(xb)

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |  -1.706939   .0625907   -27.27   0.000    -1.829614   -1.584263

. help margins

. margins, atmeans

Adjusted predictions                            Number of obs     =      3,000
Model VCE    : OIM

Expression   : Pr(outcome), predict()
at           : distance        =    58.58566 (mean)

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |   .1535612   .0081355    18.88   0.000     .1376158    .1695066

. margins, predict(xb) atmeans

Adjusted predictions                            Number of obs     =      3,000
Model VCE    : OIM

Expression   : Linear prediction (log odds), predict(xb)
at           : distance        =    58.58566 (mean)

             |            Delta-method
             |     Margin   Std. Err.      z    P>|z|     [95% Conf. Interval]
       _cons |  -1.706939   .0625907   -27.27   0.000    -1.829614   -1.584263
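
A sketch of how the four benchmark quantities above could be computed by hand for a logit model, assuming simulated data in place of margex (so the numbers will not match the Stata output). The helper se() and all variable names are illustrative only. It also shows why predict(xb) gives the same margin and standard error with and without atmeans, while the probability-scale results differ.

```r
# Simulated stand-in for margex (an assumption; values will differ from Stata's)
set.seed(2)
n <- 1000
d <- data.frame(distance = rexp(n, 1 / 50))
d$outcome <- rbinom(n, 1, plogis(-1.5 - 0.005 * d$distance))

m   <- glm(outcome ~ distance, data = d, family = binomial)
X   <- model.matrix(m)
b   <- coef(m)
V   <- vcov(m)
eta <- drop(X %*% b)

# delta-method SE for a quantity with gradient g with respect to the coefficients
se <- function(g) sqrt(drop(t(g) %*% V %*% g))

xbar    <- colMeans(X)
eta_bar <- sum(xbar * b)

# Linear prediction: averaging predictions or predicting at the means is identical
link_avg     <- c(margin = mean(eta), se = se(xbar))  # margins, predict(xb)
link_atmeans <- c(margin = eta_bar,   se = se(xbar))  # margins, predict(xb) atmeans

# Probability scale: plogis() is nonlinear, so the two no longer agree
prob_avg     <- c(margin = mean(plogis(eta)),
                  se     = se(colMeans(dlogis(eta) * X)))  # margins
prob_atmeans <- c(margin = plogis(eta_bar),
                  se     = se(dlogis(eta_bar) * xbar))     # margins, atmeans
```

For a logit model with an intercept, the average predicted probability equals the sample proportion of the outcome, which is why both logit benchmarks report the same margin (.1696667) with different standard errors.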