leeper / margins

An R Port of Stata's 'margins' Command
https://cloud.r-project.org/package=margins

Unit-level SE speedup with simple algebra trick #178

Closed vincentarelbundock closed 1 year ago

vincentarelbundock commented 3 years ago

Hi @leeper,

I think you saw on Twitter that I just released a clone of (homage to?) margins to CRAN. I’ve been running some benchmarks and finding significant speed gains when unit_ses=TRUE. I’m not 100% sure, but I think there’s a very good chance that you could close that gap by implementing a super simple algebra trick.

The opportunity is that JVJ' requires computing a very large n × n matrix, where n is the number of rows of the Jacobian J. But since we only need its diagonal to compute standard errors, we can use a computational trick to get a roughly 1000x speedup and a much lighter memory footprint.
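To spell out why the full product is unnecessary, here is a short derivation. It assumes J is the n × k Jacobian and V the k × k variance-covariance matrix; the index notation is introduced here for illustration and is not from the margins code base.

```latex
\[
  \left[ J V J^{\top} \right]_{ii}
  = \sum_{a=1}^{k} \sum_{b=1}^{k} J_{ia} V_{ab} J_{ib}
  = \sum_{b=1}^{k} (J V)_{ib} \, J_{ib},
  \qquad i = 1, \dots, n .
\]
```

In other words, the diagonal is the vector of row sums of the elementwise product of JV and J. That needs only n × k intermediate matrices, whereas forming JVJ' first allocates n × n entries.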

I’m not sure exactly where this could be implemented because I haven’t looked at your code base in a little while, but I thought I’d leave this here in case you or a contributor feels like looking into this.

To illustrate, I’ll do a similar calculation with model matrices instead of Jacobians. First, let’s create a variance-covariance matrix and a large model matrix by stacking mtcars multiple times:

library(bench)

mod <- lm(mpg ~ factor(cyl) + hp + drat, mtcars)
mtbig <- do.call("rbind", lapply(1:500, \(x) mtcars))
mm <- model.matrix(mod, data = mtbig)
vcovmat <- vcov(mod)

Before running the benchmark, let’s convince ourselves that the two strategies yield the same results at 10 digits tolerance:

old <- sqrt(diag(mm %*% vcovmat %*% t(mm)))
new <- sqrt(colSums(t(mm %*% vcovmat) * t(mm)))

identical(round(old, 10), round(new, 10))
# [1] TRUE
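
The same idea can be wrapped in a small helper. This is only a sketch: `diag_se` is a hypothetical name, not a function in margins, and it uses `rowSums()` on the untransposed product, which should be equivalent to the `colSums(t(...))` form above. It is checked here on the unstacked mtcars data, where the full product is cheap to compute.

```r
# Hypothetical helper (not part of margins): computes
# sqrt(diag(J %*% V %*% t(J))) without forming the n x n product.
diag_se <- function(J, V) {
  # J %*% V is n x k; multiplying it elementwise by J and summing each
  # row gives the i-th diagonal element of J V J'.
  sqrt(rowSums((J %*% V) * J))
}

# Small self-contained check on mtcars (32 rows).
mod <- lm(mpg ~ factor(cyl) + hp + drat, data = mtcars)
J <- model.matrix(mod)
V <- vcov(mod)

full <- sqrt(diag(J %*% V %*% t(J)))  # naive approach: builds a 32 x 32 matrix
fast <- diag_se(J, V)                 # row-wise approach: stays 32 x k

# Compare values only, ignoring any names carried over from the row labels.
all.equal(unname(full), unname(fast))
```

Using `all.equal()` rather than rounding plus `identical()` compares within a numerical tolerance, which is usually the more robust check for floating-point results.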

Run benchmark:

bench::mark(
    sqrt(diag(mm %*% vcovmat %*% t(mm))),
    sqrt(colSums(t(mm %*% vcovmat) * t(mm))))
# Warning: Some expressions had a GC in every iteration; so filtering is disabled.
# # A tibble: 2 × 6
#   expression                                    min   median `itr/sec` mem_alloc
#   <bch:expr>                               <bch:tm> <bch:tm>     <dbl> <bch:byt>
# 1 sqrt(diag(mm %*% vcovmat %*% t(mm)))        1.26s    1.26s     0.792    1.91GB
# 2 sqrt(colSums(t(mm %*% vcovmat) * t(mm))) 723.55µs   1.77ms   549.       1.95MB
# # … with 1 more variable: gc/sec <dbl>