kosukeimai / MatchIt

R package MatchIt

long execution time for large data #204

Open zbig117 opened 2 months ago

zbig117 commented 2 months ago

MatchIt takes almost 4 hours for 1:1 greedy (nearest neighbor) matching:

> Sys.time()
[1] "2024-09-17 17:47:04 EDT"
> 
> res<-MatchIt::matchit(formula = T ~ PS
+                       ,data = dat
+                       ,method = "nearest"
+                       ,caliper=.2
+                       ,std.caliper = TRUE
+                       ,distance = logitnorm::logit(unlist(dat['PS'],use.names = F))
+                       ,ratio = 1
+                       ,exact = NULL
+                       ,replace=FALSE
+                       ,verbose=TRUE
+ )
Nearest neighbor matching... 
0%   10   20   30   40   50   60   70   80   90   100%
[----|----|----|----|----|----|----|----|----|----|
**************************************************|
Calculating matching weights... Done.
Warning message:
Fewer control units than treated units; not all treated units will get a match. 
> 
> Sys.time()
[1] "2024-09-17 21:24:12 EDT"

For the same 1 million records, a simple script that I believe does the same thing (dat1 = treated, dat0 = controls) takes about 30 minutes:

> sd=var(c(dat1$lps,dat0$lps))^.5   # SD of the logit PS across both groups
> clpr=.2*sd                        # caliper width: 0.2 SD
> mtch1=mtch0=NULL                  # matched treated / control subject IDs
> Sys.time()
[1] "2024-09-18 14:38:29 EDT"
> for(i1 in 1:nrow(dat1)){
+   whmin=which.min(abs(dat1$lps[i1]-dat0$lps))   # closest remaining control
+   mindist=abs(dat1$lps[i1]-dat0$lps[whmin])
+   if(mindist>clpr)next                          # skip if outside the caliper
+   mtch1=c(mtch1,dat1$subjid[i1])
+   mtch0=c(mtch0,dat0$subjid[whmin])
+   dat0$lps[whmin]=Inf                           # remove matched control from the pool
+ }
> Sys.time()
[1] "2024-09-18 15:06:09 EDT"
> Sys.time()
[1] "2024-09-18 15:06:09 EDT"

Where is the catch?

ngreifer commented 1 month ago

I saw your email about this and have not had time to investigate further. If your code works for you, then use it! MatchIt provides extremely general code that can handle many combinations of matching parameters, which can add computation time. I am surprised that it adds so much time, so I do want to investigate improving its performance, but that is a low priority for me right now. Nearest neighbor matching is not optimized for large datasets, but other matching methods, like subclassification and generalized full matching (method = "quick"), are, so I recommend those if you want speed.
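For illustration, calls to those two faster methods might look like this (a minimal sketch reusing the OP's dat, T, and PS objects; the number of subclasses is an arbitrary choice):

library(MatchIt)

# Propensity score subclassification
m_sub <- matchit(T ~ PS, data = dat, method = "subclass",
                 distance = dat$PS, subclass = 10)

# Generalized full matching (requires the quickmatch package)
m_quick <- matchit(T ~ PS, data = dat, method = "quick",
                   distance = dat$PS)

summary(m_sub)  # check balance in the subclassified sample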

ngreifer commented 1 month ago

I took some time to dig deeper into the code for nearest neighbor matching and was able to find many opportunities for improvement. These are available in the development version, which you can install using pak::pkg_install("ngreifer/MatchIt"). It'll be on CRAN within the next few weeks.

These optimizations dramatically improve speed for large datasets. A dataset of 10 million records now takes less than a minute to match. I also added an estimate of the time remaining to the progress bar so you can see how long it takes. Below are some benchmarks comparing the performance of the old and new versions of the package. I would appreciate it if you could try out the new version on your dataset and see if you get reasonable results.
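For reference, the benchmarks assume the following setup, which is not shown in the reprex: matchit_old() and matchit_new() stand for matchit() from the CRAN release and the development version of MatchIt, respectively, and microbenchmark() comes from the microbenchmark package.

library(microbenchmark)
# matchit_old() / matchit_new(): matchit() from the CRAN release and the
# development version of MatchIt, respectively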

n <- 1e4
p <- runif(n, 0, .4)   # propensity scores
a <- rbinom(n, 1, p)   # treatment assignment
d <- data.frame(p, a)

# 1:1 matching w/o replacement, m.order = "largest" (default)
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE)$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE)$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr        min        lq      mean    median        uq      max neval cld
#>   old 586.885631 595.90136 615.85469 603.73874 628.39318 667.9139    10  a 
#>   new   9.828595  10.11494  42.00253  10.37078  10.94577 639.4152    20   b

# 1:1 matching w/o replacement, m.order = "closest"
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 m.order = "closest")$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 m.order = "closest")$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr        min         lq      mean     median        uq       max neval cld
#>   old 1656.21480 1694.83668 1946.9899 1855.31458 2084.8589 2826.2016    10  a 
#>   new   77.11349   80.99866  114.0754   88.54409  153.4945  233.2797    20   b

# 1:1 matching w/ replacement
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 replace = TRUE)$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 replace = TRUE)$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr       min        lq      mean    median        uq       max neval cld
#>   old 386.55556 389.27397 439.70722 424.47301 467.74070 584.92790    10  a 
#>   new  10.87157  10.96345  11.21426  11.17664  11.35386  11.89584    20   b

# 2:1 matching w/o replacement, m.order = "largest" (default)
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 ratio = 2)$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 ratio = 2)$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr     min        lq       mean     median        uq        max neval cld
#>   old 936.416 978.60634 1013.04795 1001.05834 1060.8694 1081.10566    10  a 
#>   new  17.059  17.34317   17.67773   17.55189   17.8875   18.71939    20   b

# 2:1 matching w/o replacement, m.order = "closest"
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 m.order = "closest",
                                 ratio = 2)$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 m.order = "closest",
                                 ratio = 2)$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr       min        lq      mean    median        uq      max neval cld
#>   old 1497.0966 1592.9533 1720.1074 1662.0881 1778.2476 2196.559    10  a 
#>   new  575.2696  579.9569  659.0618  624.9534  669.6305 1040.166    20   b

# 2:1 matching w/ replacement
microbenchmark(old = matchit_old(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 replace = TRUE,
                                 ratio = 2)$match.matrix,
               new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE,
                                 replace = TRUE,
                                 ratio = 2)$match.matrix,
               times = c(10, 20), check = "equivalent")
#> Unit: milliseconds
#>  expr       min        lq      mean    median        uq       max neval cld
#>   old 629.89802 683.57718 691.95727 690.15760 709.90165 765.65008    10  a 
#>   new  13.88834  14.02107  14.39571  14.11364  14.43978  16.08355    20   b

# Large dataset
n <- 1e6
p <- runif(n, 0, .4)
a <- rbinom(n, 1, p)
d <- data.frame(p, a)

microbenchmark(new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE)$match.matrix,
               times = 20)
#> Unit: seconds
#>  expr      min       lq     mean   median       uq      max neval
#>   new 1.126528 1.229169 1.365787 1.354254 1.440692 1.923601    20

# Large dataset
n <- 1e7
p <- runif(n, 0, .4)
a <- rbinom(n, 1, p)
d <- data.frame(p, a)

microbenchmark(new = matchit_new(a ~ p, data = d, distance = d$p,
                                 caliper = .01, std.caliper = FALSE)$match.matrix,
               times = 5)
#> Unit: seconds
#>  expr      min       lq     mean   median       uq      max neval
#>   new 15.16614 16.86033 17.23104 16.93776 17.07013 20.12084     5

Created on 2024-10-24 with reprex v2.1.1
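
To try the development version on the original problem, a sketch of the re-run using the OP's dat, T, and PS objects (not part of the reprex above):

# pak::pkg_install("ngreifer/MatchIt")  # install the development version
library(MatchIt)

res <- matchit(T ~ PS, data = dat, method = "nearest",
               distance = logitnorm::logit(dat$PS),
               caliper = .2, std.caliper = TRUE,
               ratio = 1, replace = FALSE, verbose = TRUE)

summary(res)            # covariate balance and matched sample sizes
head(res$match.matrix)  # treated units and their matched controls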