py-econometrics / duckreg

Every big regression is a small regression with weights.
MIT License

R version? #10

Open grantmcdermott opened 2 months ago

grantmcdermott commented 2 months ago

Not sure of your motivation for polyglot support, but I quickly knocked out a basic R version this evening. Only the main function, so no Mundlak, DD, CRV, etc. But perhaps it could serve as a helpful base to build on.

It seems to be doing everything correctly (including on some NYC taxi data, which I had trouble getting to work via the Python frontend for one reason or another; I might be doing something wrong). Bonus: small dependency footprint.

library(Matrix)
library(duckdb)

duckreg = function(
    fml,
    conn = NULL,
    table = NULL,
    vcov = "hc1"
    ) {

      if (is.null(conn)) {
         conn = dbConnect(duckdb(), shutdown = TRUE)
      }

      # vars of interest
      vars = all.vars(fml)
      yvar = vars[1]
      xvars = vars[-1]

      # query string
      query_string = paste0(
         "
         WITH cte AS (
            SELECT
               ",
         paste(xvars, collapse = ", "), ",
               COUNT(*) AS n,
               SUM(", yvar, ") AS sum_Y,
               SUM(POW(", yvar, ", 2)) AS sum_Y_sq
            FROM ", table, "
            GROUP BY ALL
         )
         FROM cte
         SELECT
            *,
            sum_Y / n AS mean_Y,
            sqrt(n) AS wts
         "
      )

      # fetch data
      compressed_dat = dbGetQuery(conn = conn, query_string)

      # design and outcome matrices
      X = sparse.model.matrix(reformulate(c(xvars)), compressed_dat)
      Y = compressed_dat[, "mean_Y"]
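      # multiplying both X and Y by sqrt(n) makes this a WLS fit with weights n,
      # which reproduces the full-data OLS point estimates from the compressed data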
      Xw = X * compressed_dat[["wts"]]
      Yw = Y * compressed_dat[["wts"]]

      # beta values
      betahat = chol2inv(chol(crossprod(Xw))) %*% crossprod(Xw, Yw)

      # standard errors (currently only HC1)
      if (tolower(vcov) == "hc1") {
         n = compressed_dat[["n"]]
         yprime = compressed_dat[["sum_Y"]]
         yprimeprime = compressed_dat[["sum_Y_sq"]]
         # Compute yhat
         yhat = X %*% betahat
         # Compute rss_g
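         # group-level RSS: sum_i (y_i - yhat_g)^2 = n_g * yhat_g^2 - 2 * yhat_g * sum_i y_i + sum_i y_i^2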
         rss_g = (yhat^2) * n - 2 * yhat * yprime + yprimeprime
         # Compute vcov components
         bread = solve(crossprod(X, Diagonal(x = n) %*% X))
         meat = crossprod(X, Diagonal(x = as.vector(rss_g)) %*% X)
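         # HC1 degrees-of-freedom correction: N / (N - k), with N = total number of rows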
         n_nk = sum(n) / (sum(n) - ncol(X))
         vcv = n_nk * (bread %*% meat %*% bread)
         # grab SEs
         ses = sqrt(diag(vcv))
      } else {
         stop("Only vcov = 'hc1' is currently supported.")
      }

      # return object
      ret = cbind(estimate = betahat[, 1], std.error = ses)

      return(ret)
}

Example 1: Synthetic data

Re-using the `large_dataset.db` synthetic data from the [Intro notebook](https://github.com/apoorvalal/duckreg/blob/master/notebooks/introduction.ipynb):

```r
system.time({
   mod = duckreg(Y ~ D | f1 + f2, con, "data")
})
#>   user  system elapsed
#>  1.222   0.023   0.116
mod
#>                  estimate    std.error
#> (Intercept) -0.0002736823 8.620003e-04
#> D            0.9993471218 6.324180e-04
#> f1           1.0000353622 5.484845e-05
#> f2           2.0000668533 5.482098e-05
dbDisconnect(con)
```

Example 2: NYC taxi data

Regressing 3 months of NYC taxi data. See [here](https://grantmcdermott.com/duckdb-polars/requirements.html#nyc-taxi-data) for download instructions. It requires a bit more setup to get the data ready, but this time I'll use an (ephemeral) in-memory database rather than a persistent one, just to demonstrate.

```r
con = dbConnect(duckdb(), shutdown = TRUE)

# create the table in DuckDB's memory bank
dbExecute(
   con,
   "
   CREATE TABLE taxi AS
      FROM 'nyc-taxi/**/*.parquet'
      SELECT
         tip_amount, trip_distance, passenger_count,
         vendor_id, payment_type, dropoff_at,
         dayofweek(dropoff_at) AS dofw
      WHERE year = 2012 AND CAST(month AS INTEGER) <= 3
   "
)

# our formula
fml = tip_amount ~ trip_distance + passenger_count | dofw + vendor_id + payment_type

# run the model
system.time({
   mod = duckreg(fml, con, "taxi")
})
#>   user  system elapsed
#> 12.279   1.109   1.587
mod
#>                      estimate    std.error
#> (Intercept)      1.5466728530 7.079425e-04
#> trip_distance    0.2089613849 2.142470e-04
#> passenger_count -0.0070995354 1.450711e-04
#> dofw             0.0008833547 9.034392e-05
#> vendor_idVTS    -0.0336614402 3.926514e-04
#> payment_typeCSH -2.0457657287 3.658931e-04
#> payment_typeDIS -2.1900513745 1.676851e-02
#> payment_typeNOC -2.0951645887 9.272809e-03
#> payment_typeUNK  0.6852192686 2.558176e-02
dbDisconnect(con)
```

For comparison, here is the native `fixest` output after reading the full data into memory. The `fixest` model takes about 13 seconds to run on my machine... so roughly an 8x speedup in this case, although the point estimates are slightly off. OTOH, this doesn't include the difference in data I/O time, and the `fixest` model is obviously a bit more complicated (e.g. `dofw` enters there as a full set of fixed effects, whereas my quick R version just treats the `|` terms as regular regressors).

```r
summary(feols_reg, vcov = 'hc1')
#> OLS estimation, Dep. Var.: tip_amount
#> Observations: 46,099,576
#> Fixed-effects: dofw: 7,  vendor_id: 2,  payment_type: 5
#> Standard-errors: Heteroskedasticity-robust
#>                  Estimate Std. Error  t value  Pr(>|t|)
#> trip_distance    0.209141   0.000214 975.3825 < 2.2e-16 ***
#> passenger_count -0.006061   0.000145 -41.7426 < 2.2e-16 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> RMSE: 1.23099   Adj. R2: 0.516743
#>               Within R2: 0.229813
```
apoorvalal commented 2 months ago

This is great; thanks! Maybe add the glue package as a dependency to make the string interpolation a bit more readable (and make Mundlak etc. easier to do)?
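Something along these lines, for example (untested sketch; it reuses the `yvar`, `xvars`, and `table` objects that are already defined inside `duckreg()`):

```r
library(glue)

# untested sketch of the same CTE query, built with glue interpolation
xvars_csv = paste(xvars, collapse = ", ")
query_string = glue("
   WITH cte AS (
      SELECT
         {xvars_csv},
         COUNT(*) AS n,
         SUM({yvar}) AS sum_Y,
         SUM(POW({yvar}, 2)) AS sum_Y_sq
      FROM {table}
      GROUP BY ALL
   )
   FROM cte
   SELECT
      *,
      sum_Y / n AS mean_Y,
      sqrt(n) AS wts
")
```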

Feel free to send this as a PR; I haven't ever managed a Python and R package in the same repo [perhaps you have - lmk about folder structure if so].

grantmcdermott commented 2 months ago

Cool, cool. I'm up against work deadlines before I head out for a mini-break, so I can't promise anything right away... but maybe I'll get another jolt of motivation this evening. Who knows?

RE: the general structure, I believe it should be as simple as moving each language to its own sub-folder and having users install against the relevant sub-targets. I'm less sure about the CI infra and test suite; we may have to configure separate workflow files.
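For the install side, I'd guess something like this would work (untested; the `r/` sub-folder name is just a placeholder):

```r
# hypothetical layout: python/ and r/ sub-folders at the repo root,
# with users installing the R package against the r/ sub-directory
remotes::install_github("py-econometrics/duckreg", subdir = "r")
```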

apoorvalal commented 2 months ago

Honestly, it might be easier to just have a separate repo for the R package; most cases I've seen of py/R living in the same repo are ones where they rely on the same underlying C++ code or sth, which is not the case here. I'd be happy to contribute to a self-contained R package that you could spin up on your own.