dheimgartner / OPSR

GNU General Public License v3.0
0 stars 0 forks source link

Check `predict(..., type = "unlog-response")` again #9

Closed dheimgartner closed 1 week ago

dheimgartner commented 1 week ago
devtools::load_all()
sim_dat <- opsr_simulate()
dat <- sim_dat$data

# log1p(x) is the numerically accurate form of log(x + 1). Outcomes below -1
# have no real log and become NaN (with a "NaNs produced" warning); an
# outcome of exactly -1 would give -Inf.
dat$log_yo <- log1p(dat$yo)

# na.omit() drops every row with a missing value in any column, which
# includes the rows whose log_yo is NaN. Both models below are therefore fit
# on this reduced sample, so report the loss instead of letting it be silent.
n_before <- nrow(dat)
dat <- na.omit(dat)
message("na.omit() dropped ", n_before - nrow(dat), " of ", n_before, " rows")

fit <- opsr(ys | yo ~ xs1 + xs2 | xo1 + xo2, dat)
summary(fit)

fit_log <- opsr(ys | log_yo ~ xs1 + xs2 | xo1 + xo2, dat)
summary(fit_log)

texreg::screenreg(list(fit, fit_log))

boxplot(dat$yo)
boxplot(dat$log_yo)

# Plot observed against predicted outcomes for one regime (`group`) of a
# fitted model, with the 45-degree line drawn in red.
#
# Args:
#   fit:   fitted model object passed to predict().
#   group: regime to compare; rows of `data` with `ys == group` are used, and
#          the same value is passed to predict() as `group` and `counterfact`.
#   type:  prediction type forwarded to predict().
#   data:  data frame with columns `ys` and `yo`; defaults to `dat`, which is
#          looked up in the environment the function was defined in.
#
# Returns (invisibly) the plotted data frame with columns `yo` and `yo_pred`.
# Errors if the number of observed outcomes in `group` differs from the
# number of non-missing predictions, since the two cannot then be paired.
compare_fun <- function(fit, group, type = "unlog-response", data = dat) {
  p <- predict(fit, group = group, counterfact = group, type = type)

  # NOTE(review): assumes predict() returns NA for observations outside
  # `group`, so dropping NAs leaves the in-group predictions in data order.
  observed <- data$yo[which(data$ys == group)]
  predicted <- as.vector(na.omit(p))
  if (length(observed) != length(predicted)) {
    stop(
      "group ", group, ": ", length(observed), " observed outcomes but ",
      length(predicted), " non-missing predictions; cannot pair them",
      call. = FALSE
    )
  }
  compare <- data.frame(yo = observed, yo_pred = predicted)

  # Shared square window that includes the origin. It is computed from finite
  # values only: max() is NA with missing values and Inf when predictions
  # overflow, and either makes plot() fail on a non-finite xlim. The window
  # extends below 0 when some values are negative.
  vals <- unlist(compare, use.names = FALSE)
  vals <- vals[is.finite(vals)]
  lim <- c(min(0, vals), max(0, vals))

  title <- paste0("'group' = ", group)
  plot(compare, xlim = lim, ylim = lim, main = title)
  abline(a = 0, b = 1, col = "red")
  invisible(compare)
}

# One row of three panels per model: observed vs. predicted outcome for each
# group. The log-outcome model uses compare_fun's default prediction type; the
# level model is plotted with type = "response".
par(mfrow = c(1, 3))
invisible(lapply(c(1, 2, 3), function(g) compare_fun(fit_log, group = g)))

par(mfrow = c(1, 3))
invisible(lapply(
  c(1, 2, 3),
  function(g) compare_fun(fit, group = g, type = "response")
))

It predicts very extreme values (far outside the range of the original `yo`)...

dheimgartner commented 1 week ago
devtools::load_all()
sim_dat <- opsr_simulate()
dat <- sim_dat$data

# Shift the outcome so that its log transform is defined. log1p(x) is the
# numerically accurate form of log(x + 1) and is finite only for x > -1. A
# fixed shift does not guarantee that for every simulated draw, so check it
# rather than let NaN/-Inf values reach the model fit.
yo_shift <- 10
dat$yo <- dat$yo + yo_shift
stopifnot(all(dat$yo > -1, na.rm = TRUE))
dat$log_yo <- log1p(dat$yo)

fit <- opsr(ys | yo ~ xs1 + xs2 | xo1 + xo2, dat)
summary(fit)

fit_log <- opsr(ys | log_yo ~ xs1 + xs2 | xo1 + xo2, dat)
summary(fit_log)

texreg::screenreg(list(fit, fit_log))

boxplot(dat$yo)
boxplot(dat$log_yo)

# Plot observed against predicted outcomes for one regime (`group`) of a
# fitted model, with the 45-degree line drawn in red.
#
# Args:
#   fit:   fitted model object passed to predict().
#   group: regime to compare; rows of `data` with `ys == group` are used, and
#          the same value is passed to predict() as `group`. No `counterfact`
#          is passed in this version.
#   type:  prediction type forwarded to predict().
#   data:  data frame with columns `ys` and `yo`; defaults to `dat`, which is
#          looked up in the environment the function was defined in.
#
# Returns (invisibly) the plotted data frame with columns `yo` and `yo_pred`.
# Errors if the number of observed outcomes in `group` differs from the
# number of non-missing predictions, since the two cannot then be paired.
compare_fun <- function(fit, group, type = "unlog-response", data = dat) {
  p <- predict(fit, group = group, type = type)  # counterfact = group

  # NOTE(review): assumes predict() returns NA for observations outside
  # `group`, so dropping NAs leaves the in-group predictions in data order.
  observed <- data$yo[which(data$ys == group)]
  predicted <- as.vector(na.omit(p))
  if (length(observed) != length(predicted)) {
    stop(
      "group ", group, ": ", length(observed), " observed outcomes but ",
      length(predicted), " non-missing predictions; cannot pair them",
      call. = FALSE
    )
  }
  compare <- data.frame(yo = observed, yo_pred = predicted)

  # Shared square window that includes the origin. It is computed from finite
  # values only: max() is NA with missing values and Inf when predictions
  # overflow, and either makes plot() fail on a non-finite xlim. The window
  # extends below 0 when some values are negative.
  vals <- unlist(compare, use.names = FALSE)
  vals <- vals[is.finite(vals)]
  lim <- c(min(0, vals), max(0, vals))

  title <- paste0("'group' = ", group)
  plot(compare, xlim = lim, ylim = lim, main = title)
  abline(a = 0, b = 1, col = "red")
  invisible(compare)
}

# One row of three panels: observed vs. predicted outcome for each group of
# the log-outcome model, using compare_fun's default prediction type.
par(mfrow = c(1, 3))
invisible(lapply(c(1, 2, 3), function(g) compare_fun(fit_log, group = g)))