Closed: frederickluser closed this 1 year ago.
Ooh, I like this idea. I have to run off to work, but will definitely give it some thought. Tagging @vincentarelbundock in case he has any thoughts.
Also, it probably won't make much of a difference (i.e., it's not the main bottleneck), but we could think about using data.table for the collapsing, since it's already imported as a marginaleffects dependency. Another option would be the collapse package.
I'd be interested to see some profiling. My hunch is that 80%+ of that time is spent calling predict.fixest(), and there's unfortunately not much I can do about that. If that's the case, then collapsing when possible seems like a smart way forward.
I have now run all three versions: the current package, and the approach that first aggregates the data, implemented once in base R and once with data.table. I pushed the data.table version to my fork.
I used simulated data (1 million obs., 10 periods, with heterogeneity across periods and cohorts).
First, if we aggregate the data beforehand, running time drops by a factor of about 3.6 (398 sec. vs. 109 sec.), while the results are almost identical. Switching from base R to data.table does not save much running time, at least not with 1 million observations, but it does save a couple of lines of code.
# emfx: the current package
type term contrast dydx std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.32764 0.004613442 3972.660 0 18.31859 18.33668 5.659235 5.659235 0.6579023 0 TRUE
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.58526 0.002911542 12565.595 0 36.57956 36.59097 10.664235 10.664235 0.6585625 1 TRUE
3 response .Dtreat mean(TRUE) - mean(FALSE) 54.96559 0.006885473 7982.834 0 54.95209 54.97908 15.669236 15.669236 0.6592226 2 TRUE
397.75 sec elapsed
# emfx: first aggregating the data (base)
type term contrast dydx std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.33025 0.004613759 3972.953 0 18.32121 18.33929 5.659235 5.659235 0.6579023 0 TRUE
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.67197 0.002921049 12554.382 0 36.66624 36.67769 10.664235 10.664235 0.6585625 1 TRUE
3 response .Dtreat mean(TRUE) - mean(FALSE) 55.01368 0.006890785 7983.660 0 55.00018 55.02719 15.669236 15.669236 0.6592226 2 TRUE
109.01 sec elapsed
# emfx: first aggregating the data (data.table)
type term contrast dydx std.error statistic p.value conf.low conf.high predicted predicted_hi predicted_lo sex .Dtreat
1 response .Dtreat mean(TRUE) - mean(FALSE) 18.33025 0.004613759 3972.953 0 18.32121 18.33929 5.659235 5.659235 0.6579023 0 1
2 response .Dtreat mean(TRUE) - mean(FALSE) 36.67197 0.002921049 12554.383 0 36.66624 36.67769 10.664235 10.664235 0.6585625 1 1
3 response .Dtreat mean(TRUE) - mean(FALSE) 55.01368 0.006890785 7983.660 0 55.00018 55.02719 15.669236 15.669236 0.6592226 2 1
108.54 sec elapsed
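The Python snippet below puts "almost identical" and the speed-up into numbers. Python is used only for illustration; the thread's code is R. It recomputes both quantities from the dydx values and elapsed times printed above, and it introduces no new data.

```python
# Values copied from the printed output above: dydx for rows 1-3 and the elapsed times.
current = [18.32764, 36.58526, 54.96559]
aggregated = [18.33025, 36.67197, 55.01368]  # identical for the base and data.table runs
seconds = {"current": 397.75, "base": 109.01, "data.table": 108.54}

def relative_differences(reference, candidate):
    """|candidate - reference| / |reference| for each pair of estimates."""
    return [abs(c - r) / abs(r) for r, c in zip(reference, candidate)]

rel = relative_differences(current, aggregated)
# Speed-up of each aggregated version relative to the current package.
speedups = {k: seconds["current"] / v for k, v in seconds.items() if k != "current"}

print("relative differences (%):", [round(100 * d, 3) for d in rel])
print("speed-ups:", {k: round(v, 2) for k, v in speedups.items()})
```

With these numbers, the aggregated estimates differ from the current ones by at most about 0.24% (row 2). The speed-up is roughly 3.65x for both the base and data.table versions.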
Next, here is the profiling output from Rprof for emfx. Does this give you any helpful information?
# The current package
self.time self.pct total.time total.pct
"cpp_factor_matrix" 45.10 17.68 45.10 17.68
"as.vector" 42.76 16.76 59.02 23.13
"factor" 40.46 15.86 40.48 15.87
"*" 31.04 12.17 31.04 12.17
"cbind" 18.44 7.23 18.44 7.23
"cpp_get_fe_gnl" 16.94 6.64 16.94 6.64
"%*%" 16.26 6.37 16.26 6.37
"cpp_quf_gnl" 5.54 2.17 5.54 2.17
"[.data.table" 5.38 2.11 11.22 4.40
"bmerge" 3.96 1.55 4.46 1.75
[...]
# First aggregating the data (with base code)
self.time self.pct total.time total.pct
"cpp_factor_matrix" 32.66 46.76 32.66 46.76
"cpp_get_fe_gnl" 15.82 22.65 15.82 22.65
"cpp_quf_gnl" 2.62 3.75 2.62 3.75
"unlist" 2.38 3.41 3.40 4.87
"fixef.fixest" 1.72 2.46 22.38 32.04
"to_integer" 1.66 2.38 8.32 11.91
"FUN" 1.60 2.29 68.46 98.02
"quickUnclassFactor" 1.48 2.12 4.20 6.01
"is.na" 1.38 1.98 1.38 1.98
"any" 0.88 1.26 0.88 1.26
[...]
Yep, that's useful: none of the top functions are called by marginaleffects, so it's all fixest. To compute standard errors by the delta method, we use a simple epsilon (finite) difference approach: (predict(x+h) - predict(x))/h. This requires us to call predict.fixest() twice per coefficient in the model.
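The finite-difference delta method can be sketched generically in Python, used only for illustration. This is not marginaleffects' implementation. Here, predict() is a hypothetical linear stand-in for predict.fixest(), and the coefficients, covariance matrix and rows are made up. The sketch builds the Jacobian of the average prediction by perturbing one coefficient at a time, then forms the delta-method variance J V J'.

```python
import math

def predict(beta, rows):
    """Hypothetical stand-in for predict.fixest(): linear predictor, one value per row."""
    return [sum(b * x for b, x in zip(beta, row)) for row in rows]

def average_prediction(beta, rows, counter):
    """Mean prediction over rows; counts how often the (expensive) predict step runs."""
    counter["calls"] += 1
    preds = predict(beta, rows)
    return sum(preds) / len(preds)

def delta_method_se(beta, vcov, rows, h=1e-7):
    """Estimate and delta-method SE of the average prediction, forward-difference Jacobian."""
    counter = {"calls": 0}
    base = average_prediction(beta, rows, counter)  # unperturbed prediction, computed once
    jac = []
    for j in range(len(beta)):
        bumped = list(beta)
        bumped[j] += h  # perturb one coefficient at a time
        jac.append((average_prediction(bumped, rows, counter) - base) / h)
    k = len(beta)
    var = sum(jac[a] * vcov[a][b] * jac[b] for a in range(k) for b in range(k))  # J V J'
    return base, math.sqrt(var), counter["calls"]

# Made-up example: 3 coefficients, 4 rows, diagonal covariance matrix.
beta = [1.0, 2.0, 0.5]
vcov = [[0.01, 0.0, 0.0], [0.0, 0.04, 0.0], [0.0, 0.0, 0.09]]
rows = [[1, 0, 1], [1, 1, 0], [1, 1, 1], [1, 0, 0]]
est, se, calls = delta_method_se(beta, vcov, rows)
print(f"estimate={est:.4f}, se={se:.4f}, predict calls={calls}")
```

This sketch caches the unperturbed prediction, so it needs k + 1 predict calls for k coefficients rather than two per coefficient. Either way, the number of calls grows linearly with the number of coefficients, and each call touches every row passed in.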
IIRC, etwfe sets up models with many coefficients (at least one per time period?), which means there are probably a lot of predict() calls being made behind the scenes. All the data manipulation that marginaleffects does typically takes negligible time relative to that.
Setting vcov=FALSE would surely speed things up considerably...
I see. Thanks, Vincent, for the explanations.
Yes, etwfe estimates a treatment effect for every period-by-cohort combination, and hence estimates many coefficients.
But using a collapsed dataset as newdata in marginaleffects seems to help quite a lot, while the results at "the level of interest" remain basically the same.
(EDIT: And yes, with vcov=FALSE, it runs in < 2 sec.)
Thanks both for your thoughts and contributions here.
I wonder if there's some shortcut to exploit, given that we're predicting off the same set of fixed effects. OTOH, it probably won't yield much upside, given that most of the time is spent calculating the SEs. Speaking of which...
The vignette (and docs) now include a dedicated Performance tips section that covers most of what we've discussed in this thread. Happy to add more ideas if anyone has any, but I think users have pretty reasonable options at their disposal now.
Vignette looks great!
Dear Grant,
The marginaleffects::marginaleffects part in emfx becomes very slow on large datasets (> 500k rows), especially as the number of periods increases. At some point, it no longer runs in a sensible time (the running time explodes exponentially).

Currently, etwfe runs marginaleffects over the entire dat dataset. But that is actually not necessary: the results will be (almost) equivalent if you first collapse the dataset to all period-cohort combinations and take weighted means!

I added a block to my fork that collapses the data. With some simulated data (1 million obs., 10 periods), adding this block improves running time by a factor of 4 (100 sec. vs. 400 sec.), and this gain only increases exponentially with larger datasets.
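The collapsing idea can be illustrated with a small Python sketch. Python is used only for illustration; the fork's actual code is R and is not shown here. The model, the column names (cohort, period, x) and the data are hypothetical. Each cohort-period cell is reduced to one row holding the cell mean of the covariate and the number of units n. The cell-level predictions are then averaged with n as weights.

```python
from collections import defaultdict

def collapse(rows, keys=("cohort", "period"), value="x"):
    """One row per cohort-period cell: cell mean of `value` plus the cell size `n` as weight."""
    sums, counts = defaultdict(float), defaultdict(int)
    for r in rows:
        cell = tuple(r[k] for k in keys)
        sums[cell] += r[value]
        counts[cell] += 1
    return [{**dict(zip(keys, cell)), value: sums[cell] / counts[cell], "n": counts[cell]}
            for cell in sums]

def predict(row):
    """Hypothetical model: depends on cohort and period, and is linear in x."""
    return 0.5 * row["cohort"] + 1.5 * row["period"] + 2.0 * row["x"]

def weighted_mean(values, weights):
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

# Made-up unit-level data: 6 rows falling into 3 cohort-period cells.
rows = [
    {"cohort": 3, "period": 3, "x": 1.0},
    {"cohort": 3, "period": 3, "x": 3.0},
    {"cohort": 3, "period": 4, "x": 2.0},
    {"cohort": 4, "period": 4, "x": 0.0},
    {"cohort": 4, "period": 4, "x": 4.0},
    {"cohort": 4, "period": 4, "x": 5.0},
]

full = sum(predict(r) for r in rows) / len(rows)  # predict on every row
cells = collapse(rows)  # predict on 3 rows instead of 6
collapsed = weighted_mean([predict(c) for c in cells], [c["n"] for c in cells])
print(full, collapsed, len(cells))
```

The two averages agree exactly here because the hypothetical predictor is linear in x, so the mean of the predictions equals the prediction at the mean. With a nonlinear link, or covariates that vary within a cell and enter nonlinearly, they only agree approximately. Every prediction step, including each finite-difference call behind the standard errors, then runs on one row per cell instead of one row per observation.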
What do you think about this? Greetings!