PLN-team / PLNmodels

A collection of Poisson lognormal models for multivariate count data analysis
https://pln-team.github.io/PLNmodels
GNU General Public License v3.0

PLNnetwork not incorporating matrix penalty_weights when fitting models #92

Closed: dcalderon closed this issue 1 year ago

dcalderon commented 1 year ago

I'm trying to fit a PLNnetwork model with a matrix penalty through the control_main parameter, but for some reason the estimated Sigma matrices are identical to those of a model fitted without the matrix penalty. My best guess is that the weights are somehow not getting into the glassoFast() call. Here's some code to reproduce the issue:

```r
require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homogeneous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties: symmetric, unit diagonal, off-diagonal weights in [1, 5]
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

# Weights passed through control_main
models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_main = list(penalty_weights = W))

mhp <- getBestModel(models_homogeneous_penalties)
mwp <- getBestModel(models_weighted_penalties)

# the sigma matrices are basically the same
mean(abs(sigma(mhp) - sigma(mwp)))
plot(sigma(mhp), sigma(mwp))
```
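For reference, the weighted penalty expected to reach the graphical-lasso step can be sketched in base R. This is a minimal sketch, assuming the usual weighted graphical-lasso form λ Σ_jk W_jk |Ω_jk| with the diagonal included; make_penalty_weights() and weighted_l1_penalty() are hypothetical helpers, not PLNmodels functions. With all weights equal to 1 it reduces to the homogeneous penalty, so a fit whose estimates do not move when W changes suggests the weights never entered the optimization.

```r
# Hypothetical helper: symmetric penalty-weight matrix with unit diagonal,
# built the same way as W in the reprex above.
make_penalty_weights <- function(p, min = 1, max = 5) {
  W <- diag(1, p, p)
  W[upper.tri(W)] <- runif(p * (p - 1) / 2, min = min, max = max)
  W[lower.tri(W)] <- t(W)[lower.tri(W)]  # mirror upper triangle to keep W symmetric
  W
}

# Hypothetical helper: weighted L1 penalty lambda * sum_{j,k} W[j, k] * |Omega[j, k]|
# (assumed weighted graphical-lasso form; the diagonal is penalized too).
weighted_l1_penalty <- function(Omega, lambda, W = matrix(1, nrow(Omega), ncol(Omega))) {
  stopifnot(all(dim(W) == dim(Omega)), isSymmetric(W))
  lambda * sum(W * abs(Omega))
}
```

With the default all-ones W, weighted_l1_penalty(Omega, lambda) is simply lambda * sum(abs(Omega)); any other W lets each entry of Omega carry its own penalty.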

Here's the output of sessionInfo():

```
R version 4.1.2 (2021-11-01)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.6

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] PLNmodels_0.11.7

loaded via a namespace (and not attached):
  [1] backports_1.3.0             Hmisc_4.6-0                 fastmatch_1.1-3             corrplot_0.90              
  [5] VGAM_1.1-6                  BiocFileCache_2.2.0         plyr_1.8.6                  igraph_1.2.7               
  [9] lazyeval_0.2.2              splines_4.1.2               BiocParallel_1.28.0         listenv_0.8.0              
 [13] GenomeInfoDb_1.30.0         ggplot2_3.3.6               digest_0.6.28               ensembldb_2.18.0           
 [17] htmltools_0.5.2             glassoFast_1.0              fansi_0.5.0                 magrittr_2.0.1             
 [21] checkmate_2.0.0             memoise_2.0.0               BSgenome_1.62.0             cluster_2.1.2              
 [25] globals_0.15.1              Biostrings_2.62.0           matrixStats_0.61.0          prettyunits_1.1.1          
 [29] jpeg_0.1-9                  colorspace_2.0-2            blob_1.2.2                  rappdirs_0.3.3             
 [33] xfun_0.27                   dplyr_1.0.7                 crayon_1.4.2                RCurl_1.98-1.5             
 [37] lme4_1.1-27.1               survival_3.2-13             VariantAnnotation_1.40.0    glue_1.6.2                 
 [41] gtable_0.3.0                zlibbioc_1.40.0             XVector_0.34.0              DelayedArray_0.20.0        
 [45] future.apply_1.8.1          SingleCellExperiment_1.16.0 BiocGenerics_0.40.0         scales_1.1.1               
 [49] cicero_1.3.6                DBI_1.1.1                   Signac_1.8.0                Rcpp_1.0.7                 
 [53] progress_1.2.2              htmlTable_2.3.0             foreign_0.8-81              bit_4.0.4                  
 [57] Formula_1.2-4               stats4_4.1.2                htmlwidgets_1.5.4           httr_1.4.2                 
 [61] RColorBrewer_1.1-2          ellipsis_0.3.2              pkgconfig_2.0.3             XML_3.99-0.8               
 [65] Gviz_1.38.0                 nnet_7.3-16                 dbplyr_2.1.1                utf8_1.2.2                 
 [69] tidyselect_1.1.1            rlang_1.0.2                 AnnotationDbi_1.56.1        munsell_0.5.0              
 [73] tools_4.1.2                 cachem_1.0.6                cli_3.3.0                   generics_0.1.1             
 [77] RSQLite_2.2.8               stringr_1.4.0               fastmap_1.1.0               yaml_2.2.1                 
 [81] knitr_1.36                  bit64_4.0.5                 purrr_0.3.4                 KEGGREST_1.34.0            
 [85] AnnotationFilter_1.18.0     pbapply_1.5-0               future_1.26.1               nlme_3.1-153               
 [89] monocle3_1.2.9              RcppRoll_0.3.0              xml2_1.3.2                  biomaRt_2.50.0             
 [93] compiler_4.1.2              rstudioapi_0.13             filelock_1.0.2              curl_4.3.2                 
 [97] png_0.1-7                   tibble_3.1.5                stringi_1.7.5               GenomicFeatures_1.46.1     
[101] lattice_0.20-45             ProtGenerics_1.26.0         Matrix_1.3-4                nloptr_1.2.2.3             
[105] vctrs_0.4.1                 pillar_1.6.4                lifecycle_1.0.1             data.table_1.14.2          
[109] bitops_1.0-7                irlba_2.3.3                 patchwork_1.1.1             rtracklayer_1.54.0         
[113] GenomicRanges_1.46.0        R6_2.5.1                    BiocIO_1.4.0                latticeExtra_0.6-29        
[117] gridExtra_2.3               IRanges_2.28.0              parallelly_1.32.0           codetools_0.2-18           
[121] dichromat_2.0-0             boot_1.3-28                 MASS_7.3-54                 assertthat_0.2.1           
[125] SummarizedExperiment_1.24.0 rjson_0.2.20                SeuratObject_4.0.2          GenomicAlignments_1.30.0   
[129] Rsamtools_2.10.0            S4Vectors_0.32.0            GenomeInfoDbData_1.2.7      parallel_4.1.2             
[133] hms_1.1.1                   lyon_0.0.1                  terra_1.5-34                grid_4.1.2                 
[137] rpart_4.1-15                tidyr_1.1.4                 minqa_1.2.4                 MatrixGenerics_1.6.0       
[141] biovizBase_1.42.0           Biobase_2.54.0              base64enc_0.1-3             restfulr_0.0.13
```

jchiquet commented 1 year ago

Thanks for bringing the problem up, I'll look into it.

jchiquet commented 1 year ago

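Ok, I get it. Two things:

1. The penalty_weights must be passed through control_init, not control_main.
2. sigma() currently returns the non-regularized covariance, which is indeed confusing, so I am going to change the default behavior so that sigma() sends back the regularized covariance when calling PLNnetwork. Sorry for the confusion.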

So the following basically works:

```r
require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homogeneous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

# Weights now passed through control_init (not control_main)
models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_init = list(penalty_weights = W))

plot(models_homogeneous_penalties)
plot(models_weighted_penalties)

mhp <- getBestModel(models_homogeneous_penalties)
mwp <- getBestModel(models_weighted_penalties)

# Regularized covariances, obtained by inverting the sparse precision matrices
Sigma_reg_mhp <- solve(mhp$model_par$Omega)
Sigma_reg_mwp <- solve(mwp$model_par$Omega)

# the sigma matrices are slightly different
mean(abs(Sigma_reg_mhp - Sigma_reg_mwp))
plot(Sigma_reg_mhp, Sigma_reg_mwp)
```
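Since the regularized covariance is obtained here as solve(Omega), a quick way to tell whether the weights changed anything is to compare the two precision matrices directly, including their sparsity patterns. A minimal base-R sketch, assuming both fits expose model_par$Omega as in the code above; compare_precisions() and its tolerance for calling an entry non-zero are hypothetical choices:

```r
# Hypothetical helper: compare two precision matrices through their
# regularized covariances (inverses) and their edge sets.
compare_precisions <- function(Omega1, Omega2, tol = 1e-8) {
  # An edge is a non-zero off-diagonal entry (upper triangle only, Omega is symmetric)
  edges <- function(Omega) abs(Omega[upper.tri(Omega)]) > tol
  e1 <- edges(Omega1)
  e2 <- edges(Omega2)
  list(
    mean_abs_sigma_diff = mean(abs(solve(Omega1) - solve(Omega2))),
    n_edges = c(first = sum(e1), second = sum(e2)),
    n_edges_differing = sum(e1 != e2)
  )
}
```

For the fits above, compare_precisions(mhp$model_par$Omega, mwp$model_par$Omega) would report the mean absolute difference between the regularized covariances, the number of edges in each network, and how many edges differ between them.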

dcalderon commented 1 year ago

Thanks for the fast and very helpful response! I was now able to get different model estimates of Sigma (from solve(Omega)) by including the penalty matrix in control_init. However, this only worked after I switched over to the dev branch (version 0.11.7-9600). Using the CRAN version of PLNmodels (version 0.11.7), I was still getting identical estimates. I'm sharing this just because I was still a bit confused at first, but from now on I'll probably work from the dev branch. Thanks again!!
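Because the weighted fit behaved as expected only from the development version onward, a script relying on control_init weights might guard against older installs. A minimal sketch, assuming the development version number quoted above (0.11.7-9600) is the cut-off; has_penalty_weights_fix() is a hypothetical helper, not part of PLNmodels:

```r
# Hypothetical helper: TRUE if a PLNmodels version string is at least the
# development version reported to honour penalty_weights in control_init.
has_penalty_weights_fix <- function(version, required = "0.11.7-9600") {
  # package_version() handles both "." and "-" separators and compares component-wise
  package_version(version) >= package_version(required)
}
```

A script could then call has_penalty_weights_fix(as.character(packageVersion("PLNmodels"))) and warn before fitting if it returns FALSE.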

jchiquet commented 1 year ago

I merged this into master today, so the following code now gives what you expect:

```r
require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homogeneous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_init = list(penalty_weights = W))

plot(models_homogeneous_penalties)
plot(models_weighted_penalties)

mhp <- getBestModel(models_homogeneous_penalties)
mwp <- getBestModel(models_weighted_penalties)

# sigma() now returns the regularized covariance: the matrices are slightly different
mean(abs(sigma(mhp) - sigma(mwp)))
plot(sigma(mhp), sigma(mwp))
```

dcalderon commented 1 year ago

Great, thank you!