edunford / tidysynth

A tidy implementation of the synthetic control method in R

Improve speed? #17

Closed · etiennebacher closed this issue 2 years ago

etiennebacher commented 2 years ago

Hello, this is a super useful package; I find it much more convenient to use than {Synth}. However, it is also slower. Below is a small benchmark comparing the performance of Synth and tidysynth on the basque data from Abadie & Gardeazabal (2003). I tried to adapt the code from the Synth documentation to work with tidysynth (there are some steps from Synth::dataprep() that I didn't adapt, but they shouldn't have a large effect, and even if they did, they would reduce the gap between Synth and tidysynth):

library(bench)
library(Synth)
#> ##
#> ## Synth Package: Implements Synthetic Control Methods.
#> ## See https://web.stanford.edu/~jhain/synthpage.html for additional information.
library(tidysynth)

data(basque)

##############################
## Synth ##
##############################

synth_ <- function() {

  dataprep.out <-
    dataprep(
      foo = basque
      ,predictors= c("school.illit",
                     "school.prim",
                     "school.med",
                     "school.high",
                     "school.post.high"
                     ,"invest"
      )
      ,predictors.op = c("mean")
      ,dependent     = c("gdpcap")
      ,unit.variable = c("regionno")
      ,time.variable = c("year")
      ,special.predictors = list(
        list("gdpcap",1960:1969,c("mean")),                            
        list("sec.agriculture",seq(1961,1969,2),c("mean")),
        list("sec.energy",seq(1961,1969,2),c("mean")),
        list("sec.industry",seq(1961,1969,2),c("mean")),
        list("sec.construction",seq(1961,1969,2),c("mean")),
        list("sec.services.venta",seq(1961,1969,2),c("mean")),
        list("sec.services.nonventa",seq(1961,1969,2),c("mean")),
        list("popdens",1969,c("mean")))
      ,treatment.identifier  = 17
      ,controls.identifier   = c(2:16,18)
      ,time.predictors.prior = c(1964:1969)
      ,time.optimize.ssr     = c(1960:1969)
      ,unit.names.variable   = c("regionname")
      ,time.plot            = c(1955:1997) 
    )

  dataprep.out$X1["school.high",] <- 
    dataprep.out$X1["school.high",] + 
    dataprep.out$X1["school.post.high",]
  dataprep.out$X1                 <- 
    as.matrix(dataprep.out$X1[
      -which(rownames(dataprep.out$X1)=="school.post.high"),])
  dataprep.out$X0["school.high",] <- 
    dataprep.out$X0["school.high",] + 
    dataprep.out$X0["school.post.high",]
  dataprep.out$X0                 <- 
    dataprep.out$X0[
      -which(rownames(dataprep.out$X0)=="school.post.high"),]

  # 2. make total and compute shares for the schooling categories
  lowest  <- which(rownames(dataprep.out$X0)=="school.illit")
  highest <- which(rownames(dataprep.out$X0)=="school.high")

  dataprep.out$X1[lowest:highest,] <- 
    (100 * dataprep.out$X1[lowest:highest,]) /
    sum(dataprep.out$X1[lowest:highest,])
  dataprep.out$X0[lowest:highest,] <-  
    100 * scale(dataprep.out$X0[lowest:highest,],
                center=FALSE,
                scale=colSums(dataprep.out$X0[lowest:highest,])
    )

  # run synth
  synth(data.prep.obj = dataprep.out)

}

##############################
## Tidysynth ##
##############################

tidysynth_ <- function() {
  basque %>% 
    synthetic_control(
      outcome = gdpcap,
      unit = regionno,
      time = year,
      i_unit = 17,
      i_time = 1970
    ) %>% 
    generate_predictor(
      time_window = 1964:1969,
      across(
        all_of(c("school.illit", "school.prim", "school.med", "school.high",
                 "school.post.high", "invest")),
        ~ {
          mean(.x, na.rm = TRUE)
        }
      )
    ) %>% 
    generate_predictor(
      time_window = seq(1961, 1969, by = 2),
      across(
        all_of(c("gdpcap", "sec.agriculture", "sec.energy", "sec.industry",
                 "sec.construction", "sec.services.venta", "sec.services.nonventa",
                 "popdens")),
        ~ {
          mean(.x, na.rm = TRUE)
        }
      )
    ) %>% 
    generate_weights(
      optimization_window = 1960:1969
    ) %>%
    generate_control()
}

bench::mark(
  synth = synth_(),
  tidysynth = tidysynth_(),
  iterations = 10,
  check = FALSE
)
#> [...]
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 2 × 6
#>   expression      min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 synth         3.02s    3.43s    0.285      135MB    0.485
#> 2 tidysynth    14.79s   15.87s    0.0630     598MB    0.687

Created on 2022-06-21 by the reprex package (v2.0.1)

tidysynth is about 4-5x slower than Synth and uses more than 4x more memory. I don't know if it's possible to improve the speed while keeping the same output, but I think it would be useful to make tidysynth faster.

Anyway, thanks again for this package!

edunford commented 2 years ago

Thanks for reaching out and for using the package, and so sorry for the delayed response! This is awesome. Thanks for delving into this and taking the time to generate these comparisons. At first glance, though, I'd say the test you ran isn't a fair comparison: Synth doesn't automatically generate the placebo cases for the inference step outlined in the method. You'd have to do that manually.

If you want to emulate the baseline behavior of the Synth package using tidysynth, you'd need to set generate_placebos = FALSE in the instantiation function synthetic_control(). The assumption in tidysynth is that the method is going to be used to make some inference, so the package sets users up to do that as easily as possible. This, of course, comes at the cost of a longer run time.

Using the above example, but setting generate_placebos=FALSE, e.g.

    synthetic_control(
      outcome = gdpcap,
      unit = regionno,
      time = year,
      i_unit = 17,
      i_time = 1970,
      generate_placebos=FALSE #<<<<
    ) %>% 

I get the following run stats on my machine -- so roughly 2x slower. But then again, tidysynth is just a wrapper around the main optimization function in Synth, so the added compute time isn't surprising: there is extra overhead from storing all the relevant output. That, naturally, also increases the overall memory allocation.

# A tibble: 2 × 13
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 synth         3.31s     3.8s     0.262     133MB    0    
2 tidysynth     7.49s    8.05s     0.123     540MB    0.123
# … with 7 more variables: n_itr <int>, n_gc <dbl>,
#   total_time <bch:tm>, result <list>, memory <list>,
#   time <list>, gc <list>
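To make that storage trade-off concrete, here's a rough sketch of what the stored output buys you (assuming out holds the result of the tidysynth pipeline above; the grab_*() and plot_*() helpers read straight from the stored object, so nothing is recomputed):

out <- tidysynth_()

# Total size of the stored object: this is where the extra memory goes
print(object.size(out), units = "MB")

# Pieces needed to vet and plot the fit are already in there
out %>% grab_unit_weights()    # donor weights behind the synthetic control
out %>% grab_balance_table()   # covariate balance, treated vs. synthetic

# With generate_placebos = TRUE, the inference pieces are stored as well
out %>% grab_significance()    # placebo-based MSPE ratios and p-values
out %>% plot_placebos()        # treated vs. placebo gaps over time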

Now let's consider the alternative. That is, let's have Synth compute a synthetic control for each donor as well as the treated unit and store the output (just to get a rough estimate of size). Likewise, we'll set generate_placebos = TRUE in tidysynth. This is a more comparable setup.

library(bench)
library(Synth)
library(tidysynth)

data(basque)

##############################
## Synth ##
##############################

synth_ <- function() {

  # Units used as treatment and control
  relevant_units = c(2:16,18)
  store_output = list()

  # Iterate through each relevant unit (treated and the donors)
  for (i in relevant_units){
    # Pull out the treated unit (each donor gets a shot to lead the show)
    treated = i
    # The rest are backup dancers. 
    donors = relevant_units[which(i != relevant_units)]

    # Prep the data
    dataprep.out <-
      dataprep(
        foo = basque
        ,predictors= c("school.illit",
                       "school.prim",
                       "school.med",
                       "school.high",
                       "school.post.high"
                       ,"invest"
        )
        ,predictors.op = c("mean")
        ,dependent     = c("gdpcap")
        ,unit.variable = c("regionno")
        ,time.variable = c("year")
        ,special.predictors = list(
          list("gdpcap",1960:1969,c("mean")),                            
          list("sec.agriculture",seq(1961,1969,2),c("mean")),
          list("sec.energy",seq(1961,1969,2),c("mean")),
          list("sec.industry",seq(1961,1969,2),c("mean")),
          list("sec.construction",seq(1961,1969,2),c("mean")),
          list("sec.services.venta",seq(1961,1969,2),c("mean")),
          list("sec.services.nonventa",seq(1961,1969,2),c("mean")),
          list("popdens",1969,c("mean")))
        ,treatment.identifier  = treated
        ,controls.identifier   = donors
        ,time.predictors.prior = c(1964:1969)
        ,time.optimize.ssr     = c(1960:1969)
        ,unit.names.variable   = c("regionname")
        ,time.plot            = c(1955:1997) 
      )

    dataprep.out$X1["school.high",] <- 
      dataprep.out$X1["school.high",] + 
      dataprep.out$X1["school.post.high",]
    dataprep.out$X1                 <- 
      as.matrix(dataprep.out$X1[
        -which(rownames(dataprep.out$X1)=="school.post.high"),])
    dataprep.out$X0["school.high",] <- 
      dataprep.out$X0["school.high",] + 
      dataprep.out$X0["school.post.high",]
    dataprep.out$X0                 <- 
      dataprep.out$X0[
        -which(rownames(dataprep.out$X0)=="school.post.high"),]

    # 2. make total and compute shares for the schooling categories
    lowest  <- which(rownames(dataprep.out$X0)=="school.illit")
    highest <- which(rownames(dataprep.out$X0)=="school.high")

    dataprep.out$X1[lowest:highest,] <- 
      (100 * dataprep.out$X1[lowest:highest,]) /
      sum(dataprep.out$X1[lowest:highest,])
    dataprep.out$X0[lowest:highest,] <-  
      100 * scale(dataprep.out$X0[lowest:highest,],
                  center=FALSE,
                  scale=colSums(dataprep.out$X0[lowest:highest,])
      )

    # run synth
    output = synth(data.prep.obj = dataprep.out)

    # Roughly store the data (just to get a sense of memory costs)
    out_list = list(treated = treated, donors = donors, dataprep.out = dataprep.out, output= output)
    store_output = c(store_output,out_list)
  }

  # Return the stored fits so their memory cost is actually retained
  store_output
}

##############################
## Tidysynth ##
##############################

tidysynth_ <- function() {
  basque %>% 
    synthetic_control(
      outcome = gdpcap,
      unit = regionno,
      time = year,
      i_unit = 17,
      i_time = 1970,
      generate_placebos=TRUE
    ) %>% 
    generate_predictor(
      time_window = 1964:1969,
      across(
        all_of(c("school.illit", "school.prim", "school.med", "school.high",
                 "school.post.high", "invest")),
        ~ {
          mean(.x, na.rm = TRUE)
        }
      )
    ) %>% 
    generate_predictor(
      time_window = seq(1961, 1969, by = 2),
      across(
        all_of(c("gdpcap", "sec.agriculture", "sec.energy", "sec.industry",
                 "sec.construction", "sec.services.venta", "sec.services.nonventa",
                 "popdens")),
        ~ {
          mean(.x, na.rm = TRUE)
        }
      )
    ) %>% 
    generate_weights(
      optimization_window = 1960:1969
    ) %>%
    generate_control()
}

bench::mark(
  synth = synth_(),
  tidysynth = tidysynth_(),
  iterations = 10,
  check = FALSE
)

You'll see that in this scenario, Synth does win out: it's faster. Like a backpacker with no gear, it's easier to move when you're not lugging around any amenities (e.g. storing the output so it's easier to vet, understand, and plot the results).

My philosophy on this one is that this is an optimization at the margins. The run time would be an issue if this were a large-n method that didn't scale with larger inputs. But really, this is a small-n method hoping to extract quantitative insight (and potentially a "causal" estimate) from a qualitative design, so I'm not entirely convinced the optimization investment is worth it.

That said, if you have ideas about where and how to optimize, send me a pull request with the implementation and the benchmark comparisons; I'd love to make everything faster!
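For reference, here's a rough sketch of what such a comparison could look like, using the single-fit synth_() and tidysynth_() helpers from the first benchmark (which return the Synth::synth() fit and the generate_control() output, respectively). In this thread's setup the weights won't match exactly, since the predictor prep differs slightly between the two scripts, but on identical inputs a speed-up PR should keep them equal up to optimizer tolerance:

synth_fit     <- synth_()       # list returned by Synth::synth(), includes solution.w
tidysynth_fit <- tidysynth_()   # nested tibble returned by generate_control()

# Donor weights from each implementation, sorted so they can be compared
w_synth <- sort(as.numeric(synth_fit$solution.w))
w_tidy  <- tidysynth_fit %>%
  grab_unit_weights() %>%       # tibble of donor units and their weights
  dplyr::pull(2) %>%            # weight column, pulled by position
  sort()

all.equal(w_synth, w_tidy, tolerance = 1e-4)

# Then re-run the timing
bench::mark(
  synth = synth_(),
  tidysynth = tidysynth_(),
  iterations = 10,
  check = FALSE
)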

Thanks again!

etiennebacher commented 2 years ago

Thank you very much for the thorough answer! Being 2x slower than Synth is acceptable for me. I was really surprised when I saw it was 5x slower, but indeed it was my mistake: I forgot to add generate_placebos = FALSE in my code...