njtierney opened this issue 3 years ago
There is an issue with trying to get the predictions out of the model:
library(yahtsee)
#> Loading required package: tsibble
#>
#> Attaching package: 'tsibble'
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, union
The help file for “summary.linear.predictor” in INLA says:
A matrix containing the mean and sd (plus, possibly quantiles and cdf) of the linear predictors η in the model
The call used to fit the model is:
# hts_example_model <- fit_hts(
# #inputs are the levels of hierarchy, in order of decreasing size
# formula = pr ~ avg_lower_age + hts(who_subregion, country),
# .data = malaria_africa_ts,
# family = "gaussian",
# special_index = month_num
# )
Why is the number of rows in the predictor summary not the same as the number of rows in the data?
dim(hts_example_model$summary.linear.predictor)
#> [1] 3116 7
dim(malaria_africa_ts)
#> [1] 1046 15
Exploring further:
head(hts_example_model$summary.linear.predictor)
#> mean sd 0.025quant 0.5quant 0.975quant mode
#> APredictor.0001 0.3252638 0.09400785 0.14029390 0.3253923 0.5094822 0.3256548
#> APredictor.0002 0.2416014 0.05435747 0.13518114 0.2415354 0.3484451 0.2414064
#> APredictor.0003 0.3139412 0.06155918 0.19373684 0.3137367 0.4353352 0.3133427
#> APredictor.0004 0.1825636 0.05737481 0.06998894 0.1825701 0.2950798 0.1825834
#> APredictor.0005 0.1831001 0.05826603 0.06873181 0.1831166 0.2973431 0.1831501
#> APredictor.0006 0.2959515 0.06926421 0.16066124 0.2957178 0.4325799 0.2952749
#> kld
#> APredictor.0001 6.594523e-06
#> APredictor.0002 7.115063e-07
#> APredictor.0003 6.288670e-07
#> APredictor.0004 3.922522e-07
#> APredictor.0005 5.845695e-07
#> APredictor.0006 1.300368e-06
tail(hts_example_model$summary.linear.predictor)
#> mean sd 0.025quant 0.5quant 0.975quant
#> Predictor.2065 -0.03899639 0.11579578 -0.2697727 -0.03788110 0.1855605
#> Predictor.2066 -0.03925871 0.11709601 -0.2726995 -0.03810259 0.1877585
#> Predictor.2067 -0.03849500 0.11819788 -0.2741148 -0.03734011 0.1906912
#> Predictor.2068 -0.03781523 0.11924636 -0.2755223 -0.03665615 0.1934664
#> Predictor.2069 -0.03735069 0.12027057 -0.2771224 -0.03617712 0.1959417
#> Predictor.2070 0.25135076 0.05013947 0.1441325 0.25334592 0.3464479
#> mode kld
#> Predictor.2065 -0.03556253 1.833651e-05
#> Predictor.2066 -0.03570182 1.844227e-05
#> Predictor.2067 -0.03494526 1.824962e-05
#> Predictor.2068 -0.03425564 1.807765e-05
#> Predictor.2069 -0.03374920 1.792919e-05
#> Predictor.2070 0.25622487 2.820810e-05
head(rownames(hts_example_model$summary.linear.predictor))
#> [1] "APredictor.0001" "APredictor.0002" "APredictor.0003" "APredictor.0004"
#> [5] "APredictor.0005" "APredictor.0006"
tail(rownames(hts_example_model$summary.linear.predictor))
#> [1] "Predictor.2065" "Predictor.2066" "Predictor.2067" "Predictor.2068"
#> [5] "Predictor.2069" "Predictor.2070"
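The row names hint at the layout. yahtsee fits with inlabru (see the "Fitting model with inlabru" message further down), and my assumption is that INLA's convention for a projection (A) matrix applies: the summary stacks the projected predictor A %*% eta ("APredictor", one row per observation) on top of the latent predictor eta ("Predictor"). If the Predictor rows run from Predictor.0001 to Predictor.2070, the arithmetic fits: 3116 - 2070 = 1046, exactly the number of rows in malaria_africa_ts. Below is a minimal base R sketch for pulling out the observation-level rows. `observation_rows()` is a hypothetical helper, not part of yahtsee or INLA, and the toy numbers are made up:

```r
# Hypothetical helper: keep only the observation-level rows of an INLA
# summary, assuming they are the rows named "APredictor.xxxx"
# (or "fitted.APredictor.xxxx" in summary.fitted.values).
observation_rows <- function(summary_df) {
  keep <- grepl("^(fitted\\.)?APredictor\\.", rownames(summary_df))
  summary_df[keep, , drop = FALSE]
}

# Toy stand-in for hts_example_model$summary.linear.predictor:
# 3 observation rows followed by 5 latent rows (values are made up).
toy <- data.frame(
  mean = c(0.33, 0.24, 0.31, -0.04, -0.04, -0.04, -0.04, 0.25),
  sd   = rep(0.1, 8),
  row.names = c(sprintf("APredictor.%04d", 1:3),
                sprintf("Predictor.%04d", 1:5))
)

obs <- observation_rows(toy)
nrow(obs) # one row per observation

# On the real model, the check would be:
# nrow(observation_rows(hts_example_model$summary.linear.predictor)) ==
#   nrow(malaria_africa_ts)
```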
Again, “summary.fitted.values” does not have the same dimensions as the data. Here is the description from the help file:
A matrix containing the mean and sd (plus, possibly quantiles and cdf) of the fitted values g^{-1}(η) obtained by transforming the linear predictors by the inverse of the link function. This quantity is only computed if marginals.fitted.values is computed. Note that if an observation is NA then the identity link is used. You can manually transform a marginal using inla.marginal.transform() or set the argument link in the control.predictor-list; see ?control.predictor
dim(hts_example_model$summary.fitted.values)
#> [1] 3116 6
dim(malaria_africa_ts)
#> [1] 1046 15
head(hts_example_model$summary.fitted.values)
#> mean sd 0.025quant 0.5quant 0.975quant
#> fitted.APredictor.0001 0.3252638 0.09400786 0.14029390 0.3253923 0.5094822
#> fitted.APredictor.0002 0.2416014 0.05435746 0.13518114 0.2415354 0.3484451
#> fitted.APredictor.0003 0.3139412 0.06155918 0.19373684 0.3137367 0.4353352
#> fitted.APredictor.0004 0.1825636 0.05737480 0.06998894 0.1825701 0.2950798
#> fitted.APredictor.0005 0.1831001 0.05826603 0.06873181 0.1831166 0.2973431
#> fitted.APredictor.0006 0.2959515 0.06926421 0.16066124 0.2957178 0.4325799
#> mode
#> fitted.APredictor.0001 0.3256548
#> fitted.APredictor.0002 0.2414064
#> fitted.APredictor.0003 0.3133426
#> fitted.APredictor.0004 0.1825834
#> fitted.APredictor.0005 0.1831501
#> fitted.APredictor.0006 0.2952749
tail(hts_example_model$summary.fitted.values)
#> mean sd 0.025quant 0.5quant 0.975quant
#> fitted.Predictor.2065 -0.03899639 0.11579578 -0.2697727 -0.03788110 0.1855605
#> fitted.Predictor.2066 -0.03925871 0.11709600 -0.2726995 -0.03810259 0.1877585
#> fitted.Predictor.2067 -0.03849500 0.11819789 -0.2741148 -0.03734011 0.1906912
#> fitted.Predictor.2068 -0.03781523 0.11924636 -0.2755223 -0.03665615 0.1934664
#> fitted.Predictor.2069 -0.03735069 0.12027057 -0.2771224 -0.03617712 0.1959417
#> fitted.Predictor.2070 0.25135075 0.05013947 0.1441325 0.25334592 0.3464479
#> mode
#> fitted.Predictor.2065 -0.03556253
#> fitted.Predictor.2066 -0.03570181
#> fitted.Predictor.2067 -0.03494526
#> fitted.Predictor.2068 -0.03425564
#> fitted.Predictor.2069 -0.03374920
#> fitted.Predictor.2070 0.25622488
head(rownames(hts_example_model$summary.fitted.values))
#> [1] "fitted.APredictor.0001" "fitted.APredictor.0002" "fitted.APredictor.0003"
#> [4] "fitted.APredictor.0004" "fitted.APredictor.0005" "fitted.APredictor.0006"
tail(rownames(hts_example_model$summary.fitted.values))
#> [1] "fitted.Predictor.2065" "fitted.Predictor.2066" "fitted.Predictor.2067"
#> [4] "fitted.Predictor.2068" "fitted.Predictor.2069" "fitted.Predictor.2070"
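The fitted values follow the same layout with a "fitted." prefix. With family = "gaussian" the link is the identity, which is why they match the linear predictor row for row in the output above. Below is a sketch of binding the observation-level columns back onto the data. It assumes the fitted.APredictor rows are in the same order as the rows of the data passed to fit_hts(), which is worth confirming against the yahtsee/inlabru internals. `attach_fitted()` is hypothetical and the toy numbers are made up:

```r
# Hypothetical helper: bind observation-level fitted values onto the data.
# Assumes rows named "fitted.APredictor.xxxx" are in the same order as `data`.
attach_fitted <- function(data, fitted_summary) {
  keep <- grepl("^fitted\\.APredictor\\.", rownames(fitted_summary))
  obs <- fitted_summary[keep, , drop = FALSE]
  if (nrow(obs) != nrow(data)) {
    stop("Found ", nrow(obs), " observation rows but data has ",
         nrow(data), " rows")
  }
  # Prefix the summary columns so they don't clash with existing columns
  colnames(obs) <- paste0("fitted_", colnames(obs))
  rownames(obs) <- NULL
  # Returns a plain data frame (a tsibble's class is not preserved)
  cbind(data, obs)
}

# Toy example (made-up numbers): 2 observations, 3 latent rows
toy_data <- data.frame(country = c("Angola", "Benin"), pr = c(0.32, 0.74))
toy_fitted <- data.frame(
  mean = c(0.33, 0.24, -0.04, -0.04, 0.25),
  sd   = c(0.09, 0.05, 0.12, 0.12, 0.05),
  row.names = c(sprintf("fitted.APredictor.%04d", 1:2),
                sprintf("fitted.Predictor.%04d", 1:3))
)
attach_fitted(toy_data, toy_fitted)
```

The row-count check is there so that a mismatch (for example, if the A-matrix layout assumption is wrong) fails loudly instead of silently misaligning predictions and observations.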
Created on 2021-06-30 by the reprex package (v2.0.0)
See https://inbo.github.io/tutorials/tutorials/r_inla/inlabru.pdf for details on prediction.
My descent into madness regarding getting predictions out: it appears the missing link is the "formula" argument of "predict".
However, the predictions are terrible!
library(yahtsee)
#> Loading required package: tsibble
#>
#> Attaching package: 'tsibble'
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, union
library(tidyverse)
malaria_africa_ts_subset <- malaria_africa_ts %>%
filter(who_subregion %in% c("AFRO-W", "AFRO-C")) %>%
group_by(who_subregion) %>%
filter(country %in% c("Angola",
"Benin",
"Cameroon",
"Cape Verde",
"Central African Republic",
"Gabon",
"Gambia",
"Guinea-Bissau",
"Nigeria",
"Togo"))
malaria_africa_ts_subset
#> # A tsibble: 186 x 15 [1D]
#> # Key: country [46]
#> # Groups: who_subregion [2]
#> who_region who_subregion country date month_num positive examined
#> <fct> <fct> <fct> <date> <dbl> <dbl> <int>
#> 1 AFRO AFRO-W Angola 1989-06-01 120 15.8 50
#> 2 AFRO AFRO-W Angola 2005-11-01 372 82 111
#> 3 AFRO AFRO-W Angola 2006-04-01 300 102 197
#> 4 AFRO AFRO-W Angola 2006-11-01 384 41 347
#> 5 AFRO AFRO-W Angola 2006-12-01 396 173 734
#> 6 AFRO AFRO-W Angola 2007-01-01 276 216 828
#> 7 AFRO AFRO-W Angola 2007-02-01 288 42 71
#> 8 AFRO AFRO-W Angola 2007-03-01 300 119 448
#> 9 AFRO AFRO-W Angola 2011-01-01 324 1 239
#> 10 AFRO AFRO-W Angola 2011-02-01 336 148 1132
#> # … with 176 more rows, and 8 more variables: pr <dbl>, avg_lower_age <dbl>,
#> # continent_id <fct>, country_id <fct>, year <int>, month <int>,
#> # avg_upper_age <dbl>, species <fct>
model_yah <- fit_hts(
#inputs are the levels of hierarchy, in order of decreasing size
formula = pr ~ avg_lower_age + hts(who_subregion, country),
.data = malaria_africa_ts_subset,
family = "gaussian",
special_index = month_num
)
#> ℹ Fitting model with inlabru
#> ✓ Fitting model with inlabru ... done
#>
model_yah
#> <hts_inla> model (fit in 12.15s)
#> Formula:
#> • ~
#> • pr
#> • avg_lower_age + hts(who_subregion, country)
#>
model_predictions <- predict(object = model_yah)
tibble(model_predictions$Predictor)
#> # A tibble: 506 × 9
#> mean sd q0.025 median q0.975 smin smax cv var
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.00270 0.00641 -0.0105 0.00263 0.0153 -0.0144 0.0218 2.37 4.11e-5
#> 2 0.00151 0.0118 -0.0149 0.000890 0.0271 -0.0386 0.0695 7.81 1.39e-4
#> 3 -0.00151 0.0136 -0.0245 -0.00124 0.0224 -0.0693 0.0315 -8.98 1.84e-4
#> 4 0.000653 0.0102 -0.0169 0.0000251 0.0228 -0.0304 0.0297 15.6 1.03e-4
#> 5 -0.000666 0.0122 -0.0254 -0.000188 0.0211 -0.0587 0.0410 -18.4 1.50e-4
#> 6 0.00116 0.0102 -0.0200 0.00129 0.0219 -0.0321 0.0328 8.81 1.04e-4
#> 7 0.00000436 0.0125 -0.0187 -0.000740 0.0305 -0.0376 0.0514 2876. 1.57e-4
#> 8 -0.000355 0.0135 -0.0233 0.000233 0.0216 -0.0816 0.0499 -38.2 1.84e-4
#> 9 0.000158 0.0124 -0.0252 0.000206 0.0229 -0.0566 0.0430 78.7 1.54e-4
#> 10 0.00183 0.00975 -0.0153 0.000566 0.0240 -0.0214 0.0305 5.33 9.51e-5
#> # … with 496 more rows
malaria_africa_ts_subset$pr
#> [1] 0.31500000 0.73873874 0.51776650 0.11815562 0.23569482 0.26086957
#> [7] 0.59154930 0.26562500 0.00418410 0.13074205 0.10107198 0.04938272
#> [13] 0.18803419 0.04152824 0.14010989 0.19752066 0.22789539 0.15978129
#> [19] 0.22033898 0.19549550 0.27333333 0.85156250 0.05177112 0.14133333
#> [25] 0.21775899 0.21913580 0.31430446 0.27266963 0.42792793 0.28571429
#> [31] 0.12759644 0.18623025 0.19940476 0.17826087 0.38565629 0.12719298
#> [37] 0.06338028 0.33514493 0.40935673 0.54516129 0.81521739 0.92553191
#> [43] 0.76826722 0.54285714 0.14071511 0.43750000 0.83333333 0.65517241
#> [49] 0.28000000 0.18439716 0.11728395 0.23529412 0.40000000 0.30188679
#> [55] 0.63461538 0.61488673 0.55026455 0.05181347 0.26236559 0.52083333
#> [61] 0.51908397 0.57458564 0.67965368 0.31967213 0.53614458 0.47989950
#> [67] 0.32532751 0.44531250 0.60680529 0.42635659 0.64885496 0.20000000
#> [73] 0.16666667 1.00000000 0.02222222 0.08438819 0.21402660 0.21590909
#> [79] 0.29256360 0.33841132 0.19723866 0.36000000 0.65116279 0.00000000
#> [85] 0.00000000 0.19590643 0.42976939 0.75862069 0.52525253 0.28279570
#> [91] 0.28382353 0.14583333 0.32647059 0.10086455 0.37974684 0.55263158
#> [97] 0.36111111 0.41401274 0.63861386 0.19018405 0.65088757 0.67213115
#> [103] 0.37704918 0.28688525 0.25405405 0.26966292 0.26146789 0.14117647
#> [109] 0.25431034 0.21379310 0.51937984 0.08510638 0.51242236 0.33091787
#> [115] 0.23275862 0.52700000 0.24900000 0.29411765 0.37786260 0.02072539
#> [121] 0.04676259 0.19583333 0.18000000 0.24120603 0.42180095 0.36005314
#> [127] 0.67256637 0.37073171 0.35958904 0.62039660 0.11658456 0.23916811
#> [133] 0.19354839 0.17035775 0.36641221 0.14917127 0.14609053 0.12430939
#> [139] 0.32136752 0.33333333 0.27813505 0.23491379 0.46428571 0.59154930
#> [145] 0.03846154 0.02718447 0.03956835 0.21003717 0.69090909 0.03301887
#> [151] 0.48669426 0.03105590 0.70178282 0.55833333 0.04000000 0.52229299
#> [157] 0.00000000 0.04599212 0.50000000 0.51977401 0.36054422 0.82524272
#> [163] 0.51851852 0.29118774 0.70879121 0.80000000 0.71366594 0.12793177
#> [169] 0.41497976 0.33303167 0.33509700 0.38405797 0.27071823 0.23970944
#> [175] 0.33097595 0.80710660 0.60957179 0.05431310 0.34669556 0.48020833
#> [181] 0.48000000 0.41047297 0.23809524 0.08552632 0.29964695 0.15384615
# it says that there are 506 values - why are there 506 predicted values
# if there are only 186 rows in the data?
new_preds <- malaria_africa_ts_subset %>%
predict(object = model_yah, formula = ~ avg_lower_age + who_subregion + country)
new_preds
#> # A tsibble: 186 x 24 [1D]
#> # Key: country [46]
#> # Groups: who_subregion [2]
#> who_region who_subregion country date month_num positive examined
#> <fct> <fct> <fct> <date> <dbl> <dbl> <int>
#> 1 AFRO AFRO-W Angola 1989-06-01 120 15.8 50
#> 2 AFRO AFRO-W Angola 2005-11-01 372 82 111
#> 3 AFRO AFRO-W Angola 2006-04-01 300 102 197
#> 4 AFRO AFRO-W Angola 2006-11-01 384 41 347
#> 5 AFRO AFRO-W Angola 2006-12-01 396 173 734
#> 6 AFRO AFRO-W Angola 2007-01-01 276 216 828
#> 7 AFRO AFRO-W Angola 2007-02-01 288 42 71
#> 8 AFRO AFRO-W Angola 2007-03-01 300 119 448
#> 9 AFRO AFRO-W Angola 2011-01-01 324 1 239
#> 10 AFRO AFRO-W Angola 2011-02-01 336 148 1132
#> # … with 176 more rows, and 17 more variables: pr <dbl>, avg_lower_age <dbl>,
#> # continent_id <fct>, country_id <fct>, year <int>, month <int>,
#> # avg_upper_age <dbl>, species <fct>, mean <dbl>, sd <dbl>, q0.025 <dbl>,
#> # median <dbl>, q0.975 <dbl>, smin <dbl>, smax <dbl>, cv <dbl>, var <dbl>
ggplot(new_preds,
aes(x = date,
y = mean,
group = country)) +
geom_line() +
geom_point(aes(y = pr)) +
facet_wrap(~country)
Created on 2022-01-27 by the reprex package (v2.0.1)
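To put a number on "terrible", one option is to compare the predicted mean against the observed pr, both of which are columns of new_preds above. Below is a sketch. `prediction_error()` is a hypothetical helper, and the toy vectors are made up:

```r
# Hypothetical helper: summarise how far predicted means are from observations
prediction_error <- function(observed, predicted) {
  resid <- observed - predicted
  c(
    rmse = sqrt(mean(resid^2, na.rm = TRUE)),
    mae = mean(abs(resid), na.rm = TRUE),
    correlation = cor(observed, predicted, use = "complete.obs")
  )
}

# Toy example (made-up values)
observed <- c(0.32, 0.74, 0.52, 0.12)
predicted <- c(0.30, 0.70, 0.50, 0.10)
prediction_error(observed, predicted)

# On the real predictions this would be:
# prediction_error(new_preds$pr, new_preds$mean)
```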
Can we add arguments to predict to also return standard error estimates and friends?
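For reference, the tsibble returned by predict() with a formula already carries sd, q0.025 and q0.975 columns (see the new_preds printout), so interval estimates can be drawn from those directly. Below is a sketch with ggplot2 that assumes those column names as printed. The toy data frame is a made-up stand-in for new_preds:

```r
library(ggplot2)

# Toy stand-in for new_preds (made-up values) with the same column names
toy_preds <- data.frame(
  country = rep(c("Angola", "Benin"), each = 3),
  date = rep(as.Date(c("2006-01-01", "2007-01-01", "2008-01-01")), 2),
  pr = c(0.32, 0.52, 0.26, 0.74, 0.12, 0.59),
  mean = c(0.30, 0.45, 0.30, 0.60, 0.20, 0.50),
  q0.025 = c(0.10, 0.25, 0.10, 0.40, 0.00, 0.30),
  q0.975 = c(0.50, 0.65, 0.50, 0.80, 0.40, 0.70)
)

p <- ggplot(toy_preds, aes(x = date, group = country)) +
  # 95% posterior interval from the returned quantile columns
  geom_ribbon(aes(ymin = q0.025, ymax = q0.975), alpha = 0.3) +
  geom_line(aes(y = mean)) +
  geom_point(aes(y = pr)) +
  facet_wrap(~country)
p
```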