bertcarnell / tornado

tornado plots for model sensitivity analysis
https://bertcarnell.github.io/tornado/
GNU General Public License v3.0

Could tornado allow for weights in a linear model? #13

Status: Closed (ghobro closed this issue 1 month ago)

ghobro commented 1 month ago

I have a very large dataset on which I fit a linear model. I want to summarise the data by different grouping variables to make the computation more streamlined. However, the tornado::tornado() function does not account for the weights when presenting the sensitivity analysis.

bertcarnell commented 1 month ago

Tornado allows for weights in the linear model, and the resulting sensitivities change with weighting as you can see below. If this is not what you were looking for, then please send a small example so that I can better understand your need.

require(tornado)
#> Loading required package: tornado
gtest_wt <- lm(mpg ~ cyl*wt*hp, data = mtcars, weights = rep(1:2, nrow(mtcars) / 2))
torn_wt <- tornado(gtest_wt, type = "PercentChange", alpha = 0.10)

gtest <- lm(mpg ~ cyl*wt*hp, data = mtcars)
torn <- tornado(gtest, type = "PercentChange", alpha = 0.10)

cbind(torn_wt$data$plotdat, torn$data$plotdat)
#>   variable      value Level variable      value Level
#> 1      cyl -0.1263769   90%      cyl -0.2142005   90%
#> 2       wt  1.3982777   90%       wt  1.4023713   90%
#> 3       hp  0.6818810   90%       hp  0.7165715   90%
#> 4      cyl  0.1263769  110%      cyl  0.2142005  110%
#> 5       wt -1.3982777  110%       wt -1.4023713  110%
#> 6       hp -0.6818810  110%       hp -0.7165715  110%

gtest_wt <- lm(mpg ~ cyl*wt*hp, data = mtcars, weights = rep(1:2, nrow(mtcars) / 2))
torn_wt <- tornado(gtest_wt, type = "ranges")

gtest <- lm(mpg ~ cyl*wt*hp, data = mtcars)
torn <- tornado(gtest, type = "ranges")

cbind(torn_wt$data$plotdat, torn$data$plotdat)
#>   variable      value Level variable      value Level
#> 1      cyl -0.4467869 Lower      cyl -0.7572746 Lower
#> 2       wt  7.4069926 Lower       wt  7.4286776 Lower
#> 3       hp  4.4015756 Lower       hp  4.6255038 Lower
#> 4      cyl  0.3701948 Upper      cyl  0.6274561 Upper
#> 5       wt -9.5909526 Upper       wt -9.6190314 Upper
#> 6       hp -8.7537605 Upper       hp -9.1991043 Upper

gtest_wt <- lm(mpg ~ cyl*wt*hp, data = mtcars, weights = rep(1:2, nrow(mtcars) / 2))
torn_wt <- tornado(gtest_wt, type = "percentiles", alpha = 0.10)

gtest <- lm(mpg ~ cyl*wt*hp, data = mtcars)
torn <- tornado(gtest, type = "percentiles", alpha = 0.10)

cbind(torn_wt$data$plotdat, torn$data$plotdat)
#>   variable      value Level variable      value Level
#> 1      cyl -0.4467869  10th      cyl -0.7572746  10th
#> 2       wt  5.4838040  10th       wt  5.4998585  10th
#> 3       hp  3.7507816  10th       hp  3.9416010  10th
#> 4      cyl  0.3701948  90th      cyl  0.6274561  90th
#> 5       wt -3.6084234  90th       wt -3.6189876  90th
#> 6       hp -4.5003568  90th       hp -4.7293105  90th

gtest_wt <- lm(mpg ~ cyl*wt*hp, data = mtcars, weights = rep(1:2, nrow(mtcars) / 2))
torn_wt <- tornado(gtest_wt, type = "StdDev", alpha = 2)

gtest <- lm(mpg ~ cyl*wt*hp, data = mtcars)
torn <- tornado(gtest, type = "StdDev", alpha = 2)

cbind(torn_wt$data$plotdat, torn$data$plotdat)
#>   variable      value        Level variable     value        Level
#> 1      cyl -0.7295326 mean - 2*std      cyl -1.236510 mean - 2*std
#> 2       wt  8.5051219 mean - 2*std       wt  8.530022 mean - 2*std
#> 3       hp  6.3743292 mean - 2*std       hp  6.698620 mean - 2*std
#> 4      cyl  0.7295326 mean + 2*std      cyl  1.236510 mean + 2*std
#> 5       wt -8.5051219 mean + 2*std       wt -8.530022 mean + 2*std
#> 6       hp -6.3743292 mean + 2*std       hp -6.698620 mean + 2*std
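To see which ingredients the weights touch in the example above, you can compare the fitted coefficients and the input means under the two weightings. This is a base-R sketch only. It does not reproduce what tornado computes internally, and `fit_wt`, `fit` and `inputs` are illustrative names.

```r
# Refit the two models from the example above (mtcars ships with base R).
w      <- rep(1:2, nrow(mtcars) / 2)   # same alternating 1, 2 weights as above
fit_wt <- lm(mpg ~ cyl*wt*hp, data = mtcars, weights = w)
fit    <- lm(mpg ~ cyl*wt*hp, data = mtcars)

# Weighting changes the fitted coefficients.
cbind(weighted = coef(fit_wt), unweighted = coef(fit))

# Compare weighted and unweighted means of the inputs under the same weights.
inputs <- mtcars[c("cyl", "wt", "hp")]
rbind(weighted   = sapply(inputs, weighted.mean, w = w),
      unweighted = colMeans(inputs))
```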

ghobro commented 1 month ago

Apologies for the delay in getting back to you. Here's a reproducible example.

# Load necessary library
library(dplyr)

# Step 1: Create a simple dataset with unequal group sizes and different means for y
set.seed(123)
data <- data.frame(
  group = rep(letters[1:3], times = c(3, 5, 4)), # groups a, b, c with different sizes
  x = rep(c(0, 5, 10), times = c(3, 5, 4)), # repeated x values within each group
  y = c(rnorm(3, mean = 0), rnorm(5, mean = 5), rnorm(4, mean = 10))
)

# Step 2: Run a linear model on the original dataset
model <- lm(y ~ x, data = data)

# Step 3: Group by 'group' and get counts
grouped_data <- data %>%
  group_by(group) %>%
  summarise(
    x = mean(x),
    y = mean(y),
    count = n()
  )

# Step 4: Run a weighted linear model using the counts
weighted_model <- lm(y ~ x, data = grouped_data, weights = count)

# Step 5: Check the coefficients to ensure they are the same
coef(model)
# (Intercept)           x 
# 0.2741926   0.9852283 
coef(weighted_model)
# (Intercept)           x 
# 0.2741926   0.9852283 
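Why the coefficients match: in Step 1, x is constant within each group. Let $x_g$ be the x value of group $g$, $n_g$ its size and $\bar y_g$ its mean response. The individual-level residual sum of squares then splits as:

```latex
\sum_{g=1}^{G}\sum_{i=1}^{n_g}\left(y_{gi} - a - b\,x_g\right)^2
  = \sum_{g=1}^{G}\sum_{i=1}^{n_g}\left(y_{gi} - \bar y_g\right)^2
  + \sum_{g=1}^{G} n_g\left(\bar y_g - a - b\,x_g\right)^2
```

The cross term vanishes because the deviations from $\bar y_g$ sum to zero within each group. The first sum does not depend on $(a, b)$, so minimizing the left side is the same as the count-weighted least squares on the group means in Step 4. If x varied within groups, averaging it would generally change the fit.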

We can see that the coefficients of the two models are the same, as expected. However, the tornado sensitivity data differ.

library(tornado)

# Create tornado from individual model
torn <- tornado(model, type = "percentiles", alpha = 0.10)

# Create tornado from weighted model
torn_weighted <- tornado(weighted_model, type = "percentiles", alpha = 0.10)

# Check the central values:
list(torn$data$pmeans[[1]], torn_weighted$data$pmeans[[1]])
# Not the same:
# [[1]]
# [1] 5.610846
# 
# [[2]]
# [1] 5.200334

It would seem that they should be the same. Indeed, if we feed the weighted mean of the grouped data into the weighted model, we get the expected value.

# We can see the central value in the weighted model is using the unweighted
# mean of the grouped x values
predict(weighted_model, tibble(x = mean(grouped_data$x)))[[1]]
# [1] 5.200334

# But it would seem that the central values should be the same, given the 
# data feeding each model are essentially the same

# if we use the mean of the ungrouped data in the weighted model (which is the 
# same as the weighted mean of the grouped data), we get the same value:
predict(weighted_model, tibble(x = mean(data$x)))[[1]]
# [1] 5.610846
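The identity this comment relies on can be checked in base R without dplyr. The vectors below are hypothetical stand-ins that reproduce only the x layout of Step 1.

```r
# Same x layout as Step 1: groups of size 3, 5 and 4 with x = 0, 5, 10.
x_ungrouped <- rep(c(0, 5, 10), times = c(3, 5, 4))  # one row per observation
x_group     <- c(0, 5, 10)                           # one x value per group
count       <- c(3, 5, 4)                            # group sizes

mean(x_ungrouped)              # 65 / 12 = 5.416667
weighted.mean(x_group, count)  # also 65 / 12
mean(x_group)                  # 5, the unweighted mean of the grouped x
```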
bertcarnell commented 1 month ago

I agree that the center of the tornado should be the prediction at the weighted mean of the input variables. Thanks for catching that, and I will fix it. I also need to think more about the tornado sensitivity endpoints for the same reason.
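Below is a minimal sketch of that idea, assuming an `lm` fit whose predictors are all numeric. `weighted_center()` and `weighted_quantile()` are hypothetical helpers written for illustration. They are not the functions used in the tornado package or in the commit linked below. The weighted quantile is only one possible way to make percentile endpoints respect the weights.

```r
# Hypothetical helper: prediction at the weighted mean of every predictor,
# using the prior weights stored in the lm object (equal weights if none).
weighted_center <- function(model) {
  mf <- model.frame(model)
  w  <- weights(model)
  if (is.null(w)) w <- rep(1, nrow(mf))
  vars <- all.vars(delete.response(terms(model)))  # predictor names only
  newdata <- as.data.frame(lapply(mf[vars], weighted.mean, w = w))
  unname(predict(model, newdata))
}

# Hypothetical helper: weighted quantile by inverting the weighted empirical
# CDF, i.e. a weighted analogue of quantile(type = 1).
weighted_quantile <- function(x, w, probs) {
  ord <- order(x)
  x   <- x[ord]
  cw  <- cumsum(w[ord]) / sum(w)
  vapply(probs, function(p) x[which(cw >= p)[1]], numeric(1))
}
```

With integer counts as weights, `weighted_quantile(x_group, count, p)` should agree with `quantile(rep(x_group, count), p, type = 1)`, which is the property a collapsed dataset needs.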

bertcarnell commented 1 month ago

I pushed the fix for this (https://github.com/bertcarnell/tornado/commit/405ec6ca8e87505fd10d1247e66c6fe3b219d5ee). You can install the development version and test it. I'll submit to CRAN in a couple of days. I'd appreciate any feedback you have.

devtools::install_github("bertcarnell/tornado")

bertcarnell commented 1 month ago

I created a release for this version, but cannot submit to CRAN during their summer downtime.

ghobro commented 1 month ago

Hello, I have tested and it now produces the expected results. Thank you.