grantmcdermott / etwfe

Extended two-way fixed effects
https://grantmcdermott.com/etwfe/

marginaleffects slow in large data sets #18

Closed frederickluser closed 1 year ago

frederickluser commented 1 year ago

Dear Grant,

The marginaleffects::marginaleffects part in emfx becomes very slow in large datasets (> 500k), especially when the number of periods increases. At some point, it does not run in a sensible time anymore (running time explodes exponentially).

Currently, etwfe runs marginaleffects over the entire dat dataset. But that is actually not necessary. The results will be (almost) equivalent if you first collapse the data set for all period-cohort combinations and take weighted means!
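
As a minimal sketch of the collapsing idea (toy data and the column names `cohort`, `period`, and `x` are made up for illustration; this is not the package's code), the cell-size-weighted mean of cohort-period cell means reproduces the full-sample mean. That identity is exact for simple means; for model predictions the thread only claims near-equivalence.

```r
set.seed(42)
n <- 10000
# Hypothetical panel-like data: a cohort, a period, and one covariate.
dat <- data.frame(
  cohort = sample(2001:2003, n, replace = TRUE),
  period = sample(2001:2005, n, replace = TRUE),
  x      = rnorm(n)
)

# One row per cohort-period cell holding the cell mean of x ...
cells <- aggregate(x ~ cohort + period, data = dat, FUN = mean)
# ... and the cell size, to be used as a weight.
sizes <- aggregate(x ~ cohort + period, data = dat, FUN = length)
names(sizes)[names(sizes) == "x"] <- "n"
cells <- merge(cells, sizes, by = c("cohort", "period"))

# Mean over all rows vs. weighted mean over the collapsed cells.
full_mean      <- mean(dat$x)
collapsed_mean <- weighted.mean(cells$x, w = cells$n)
all.equal(full_mean, collapsed_mean)
```

The collapsed table has at most one row per cohort-period pair, so anything evaluated on it scales with the number of cells rather than the number of observations.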

In my fork, I added the block below, which collapses the data. With some simulated data (1 million obs., 10 periods), adding this block improves running time by a factor of 4 (100 sec. vs. 400 sec.), and the gap only grows with larger datasets.

What do you think about this? Greetings !

# define formulas (needs an "if-else" with / without xvar)
form_count = stats::as.formula(paste(".", " ~", gvar, "+", tvar, "+", xvar))
form_data = stats::as.formula(paste(".", " ~", gvar, "+", tvar, "+", xvar, "+ .Dtreat"))

# calculate weights
dat_weights = aggregate(form_count, data = subset(dat, .Dtreat == 1), FUN = length)[c(gvar, tvar, xvar, ".Dtreat")]
names(dat_weights)[names(dat_weights) == ".Dtreat"] = "n"

# collapse the data 
dat = aggregate(form_data, data = subset(dat, .Dtreat == 1), FUN = mean, na.rm = TRUE)

# merge the weights onto the collapsed data
dat = merge(dat, dat_weights, all.x = TRUE)

# [...]
mfx = marginaleffects::marginaleffects(
  object,
  newdata = dat, # the collapsed data
  wts = "n", # the cohort-period cell size as weights
  variables = ".Dtreat", 
  by = c(by_var, xvar),
  ...
)

grantmcdermott commented 1 year ago

Ooh, I like this idea. I have to run off to work, but will definitely give it some thought. Tagging @vincentarelbundock in case he has any thoughts.

grantmcdermott commented 1 year ago

Also, it probably won't make much of a difference (i.e., it's not the main bottleneck). But we can think about using data.table for the collapsing, since it's already imported as a marginaleffects dependency. Another option would be collapse.

vincentarelbundock commented 1 year ago

I'd be interested to see some profiling. My hunch is that 80%+ of that time is spent calling predict.fixest(), and there's unfortunately not much I can do about that. If that's the case, then collapsing when possible seems like a smart way forward.
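
A minimal sketch of how such a profile could be collected with base R's `Rprof()` and `summaryRprof()`. The function `slow_step()` is a hypothetical stand-in workload; in this thread the profiled call would be `emfx()` on the fitted model.

```r
# Hypothetical stand-in workload (assumption; replace with the real call).
slow_step <- function(n = 1e5) {
  f <- factor(sample(letters, n, replace = TRUE))
  m <- model.matrix(~ f)
  crossprod(m)
}

prof_file <- tempfile(fileext = ".out")
Rprof(prof_file, interval = 0.01)   # start sampling the call stack
t0 <- proc.time()[["elapsed"]]
# Keep the workload running for about a second so samples are recorded.
while (proc.time()[["elapsed"]] - t0 < 1) slow_step()
Rprof(NULL)                         # stop profiling

prof <- summaryRprof(prof_file)
head(prof$by.self, 10)              # functions ranked by their own ("self") time
```

The `by.self` table has the `self.time`, `self.pct`, `total.time`, and `total.pct` columns, which makes it easy to see whether the time is spent inside the modelling package or in the surrounding data handling.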

frederickluser commented 1 year ago

I have now run all 3 versions that I have: the current package, and the approach that first aggregates the data, once with base R and once with data.table. I pushed the data.table version to my fork. I use simulated data (1 million obs., 10 periods, with heterogeneity between periods and cohorts).

First, running time is roughly 4x shorter if we first aggregate the data (about 109 vs. 398 sec.), while the results are nearly the same (the estimates differ by less than 0.1). Switching from base to data.table does not save much running time, at least not with 1 million observations. Yet, it saves a couple of lines of code.

# emfx: the current package
   type    term                 contrast     dydx   std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.32764 0.004613442  3972.660       0 18.31859  18.33668  5.659235     5.659235    0.6579023   0    TRUE
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.58526 0.002911542 12565.595       0 36.57956  36.59097 10.664235    10.664235    0.6585625   1    TRUE
3 response .Dtreat mean(TRUE) - mean(FALSE) 54.96559 0.006885473  7982.834       0 54.95209  54.97908 15.669236    15.669236    0.6592226   2    TRUE
397.75 sec elapsed

# emfx: first aggregating the data (base)
      type    term                 contrast     dydx   std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.33025 0.004613759  3972.953       0 18.32121  18.33929  5.659235     5.659235    0.6579023   0    TRUE
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.67197 0.002921049 12554.382       0 36.66624  36.67769 10.664235    10.664235    0.6585625   1    TRUE
3 response .Dtreat mean(TRUE) - mean(FALSE) 55.01368 0.006890785  7983.660       0 55.00018  55.02719 15.669236    15.669236    0.6592226   2    TRUE
109.01 sec elapsed

# emfx: first aggregating the data (data.table)
      type    term                 contrast     dydx   std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.33025 0.004613759  3972.953       0 18.32121  18.33929  5.659235     5.659235    0.6579023   0       1
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.67197 0.002921049 12554.383       0 36.66624  36.67769 10.664235    10.664235    0.6585625   1       1
3 response .Dtreat mean(TRUE) - mean(FALSE) 55.01368 0.006890785  7983.660       0 55.00018  55.02719 15.669236    15.669236    0.6592226   2       1
108.54 sec elapsed

Then, see profiling from Rprof for emfx. Does this give you any helpful information?

# The current package
                         self.time self.pct total.time total.pct
"cpp_factor_matrix"          45.10    17.68      45.10     17.68
"as.vector"                  42.76    16.76      59.02     23.13
"factor"                     40.46    15.86      40.48     15.87
"*"                          31.04    12.17      31.04     12.17
"cbind"                      18.44     7.23      18.44      7.23
"cpp_get_fe_gnl"             16.94     6.64      16.94      6.64
"%*%"                        16.26     6.37      16.26      6.37
"cpp_quf_gnl"                 5.54     2.17       5.54      2.17
"[.data.table"                5.38     2.11      11.22      4.40
"bmerge"                      3.96     1.55       4.46      1.75
[...]

# First aggregating the data (with base code)
                       self.time self.pct total.time total.pct
"cpp_factor_matrix"         32.66    46.76      32.66     46.76
"cpp_get_fe_gnl"            15.82    22.65      15.82     22.65
"cpp_quf_gnl"                2.62     3.75       2.62      3.75
"unlist"                     2.38     3.41       3.40      4.87
"fixef.fixest"               1.72     2.46      22.38     32.04
"to_integer"                 1.66     2.38       8.32     11.91
"FUN"                        1.60     2.29      68.46     98.02
"quickUnclassFactor"         1.48     2.12       4.20      6.01
"is.na"                      1.38     1.98       1.38      1.98
"any"                        0.88     1.26       0.88      1.26
[...]
vincentarelbundock commented 1 year ago

Yep, that's useful: none of the top functions are called by marginaleffects, so it's all fixest. To compute standard errors by the delta method, we work with a simple epsilon difference approach: (predict(x+h)-predict(x))/h. This requires us to call predict.fixest() twice per coefficient in the model.
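
A rough sketch of that finite-difference delta method, under stated assumptions: a plain `lm()` fit stands in for the fixest model, the quantity of interest is an average prediction, and a forward difference is used so there is one extra `predict()` call per coefficient. The helper `avg_pred()` is made up for illustration and is not how marginaleffects is implemented internally.

```r
# Toy data and a plain lm() fit; lm stands in for the fixest model (assumption).
set.seed(1)
d <- data.frame(x = rnorm(200), g = gl(4, 50))
d$y <- 1 + 2 * d$x + as.integer(d$g) + rnorm(200)
fit <- lm(y ~ x + g, data = d)

# Quantity of interest: the average prediction over `newdata`,
# written as a function of the coefficient vector.
avg_pred <- function(beta, model, newdata) {
  model$coefficients <- beta   # swap in the (possibly perturbed) coefficients
  mean(predict(model, newdata = newdata))
}

beta <- coef(fit)
h    <- 1e-6
base <- avg_pred(beta, fit, d)

# Forward-difference Jacobian: one extra predict() call per coefficient.
J <- vapply(seq_along(beta), function(j) {
  b <- beta
  b[j] <- b[j] + h
  (avg_pred(b, fit, d) - base) / h
}, numeric(1))

# Delta-method standard error: sqrt(J V J').
se_delta <- sqrt(drop(t(J) %*% vcov(fit) %*% J))
se_delta
```

The cost is dominated by the repeated `predict()` calls, each over every row of `newdata`, which is why both fewer rows (collapsing) and skipping the standard errors reduce the running time.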

IIRC, etwfe sets up models with many coefficients (at least one per time period?), which means there are probably a lot of predict() calls made behind the scenes. All the data manipulation that marginaleffects does typically takes no time relative to that.

Setting vcov=FALSE would surely speed things up considerably...

frederickluser commented 1 year ago

I see. Thanks Vincent for the explanations.

Yes, etwfe estimates a treatment effect for every period x cohort combination and hence estimates many coefficients. But using a collapsed dataset as newdata in marginaleffects seems to help quite a lot, while results on "the level of interest" remain basically the same.

(EDIT: And yes, with vcov=FALSE, it runs in < 2 sec)

grantmcdermott commented 1 year ago

Thanks both for your thoughts and contributions here.

I wonder if there's some shortcut to exploit given that we're predicting off the same set of fixed effects. OTOH, it probably won't yield much upside given that most of the time is spent calculating the SEs. Speaking of which...

The vignette (and docs) now include a dedicated Performance tips section that covers most of what we've discussed in this thread. Happy to add more ideas if anyone has any, but I think users have pretty reasonable options at their disposal now.

vincentarelbundock commented 1 year ago

Vignette looks great!