PLNnetwork not incorporating matrix penalty_weights when fitting models

dcalderon commented 1 year ago

I'm trying to fit a PLNnetwork model with a matrix penalty through the control_main parameter, but for some reason the estimated Sigma matrices are identical to a model fit without the matrix penalty. My best guess is that the weights are somehow not getting into the glassoFast() call. Here's some code to reproduce the issue:

require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homegenous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_main = list(penalty_weights = W))

mhp <- getBestModel(models_homegenous_penalties)
mwp <- getBestModel(models_weighted_penalties)

# the sigma matrices are basically the same
mean(abs(sigma(mhp) - sigma(mwp)))
plot(sigma(mhp), sigma(mwp))
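
Before blaming the fit, it can help to rule out the weight matrix itself. The following is a minimal base-R sketch of the same construction, checked for the properties it is meant to have (symmetric, unit diagonal, off-diagonal entries in [1, 5]). The dimension `p <- 5` and the name `w_checks` are placeholders for illustration, not part of the original report.

```r
set.seed(42)
p <- 5  # placeholder dimension; the report uses ncol(trichoptera$Abundance)

# Same construction as in the report: unit diagonal, random upper triangle,
# then mirror the upper triangle into the lower triangle
W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p * (p - 1) / 2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

# Properties the construction is intended to guarantee
w_checks <- c(
  symmetric     = isSymmetric(W),
  unit_diagonal = all(diag(W) == 1),
  offdiag_range = all(W[upper.tri(W)] >= 1 & W[upper.tri(W)] <= 5)
)
w_checks
```

If all three entries are TRUE, the matrix is built as intended and the problem lies further downstream.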

Here's the output of sessionInfo():

R version 4.1.2 (2021-11-01)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.6

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] PLNmodels_0.11.7

jchiquet commented 1 year ago

Thanks for bringing the problem up, I'll look into it.

jchiquet commented 1 year ago

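Ok, I get it. Two things:

The penalty weights have to be passed through control_init rather than control_main, and sigma() does not currently return the regularized covariance for a PLNnetwork fit, so that covariance has to be computed as the inverse of Omega. It is indeed confusing, so I am going to change the default behavior so that sigma() returns the regularized covariance when calling PLNnetwork. Sorry for the confusion.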

So the following is basically working

require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homegenous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_init = list(penalty_weights = W))


mhp <- getBestModel(models_homegenous_penalties)
mwp <- getBestModel(models_weighted_penalties)

Sigma_reg_mhp <- solve(mhp$model_par$Omega)
Sigma_reg_mwp <- solve(mwp$model_par$Omega)

# the sigma matrices are slightly different
mean(abs(Sigma_reg_mhp - Sigma_reg_mwp))
plot(Sigma_reg_mhp, Sigma_reg_mwp)
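
To make the solve(Omega) step concrete, here is a small base-R sketch on a toy precision matrix; the 3 x 3 `Omega` below is made up for illustration and has nothing to do with the trichoptera fit. It shows that the covariance implied by a sparse precision matrix is obtained by inversion, and that a zero in Omega does not translate into a zero in the covariance.

```r
# Toy precision matrix with a zero between variables 1 and 3 (illustrative only)
Omega <- matrix(c( 2, -1,  0,
                  -1,  2, -1,
                   0, -1,  2), nrow = 3, byrow = TRUE)

# Covariance implied by the precision matrix
Sigma_reg <- solve(Omega)

# Multiplying back recovers the identity up to numerical error
max(abs(Sigma_reg %*% Omega - diag(3)))

# The (1, 3) entry is zero in Omega but not in Sigma_reg
Omega[1, 3]
Sigma_reg[1, 3]
```

This is why comparing covariance matrices can hide or blur differences that are sharper on the Omega scale, where the penalty acts.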

dcalderon commented 1 year ago

Thanks for the fast and very helpful response! I was able to get different model estimates of Sigma (from solve(Omega)) by including the penalty matrix in control_init. However, this only worked after I switched to the dev branch (version 0.11.7-9600). With the CRAN version of PLNmodels (version 0.11.7) I was still getting identical estimates. I'm sharing this because it confused me at first; from now on I'll probably work from the dev branch. Thanks again!!
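
For anyone unsure which build they are running, a version comparison along these lines may help. This is only a sketch: `has_weighted_penalty_fix` is a hypothetical helper, and the threshold 0.11.7-9600 is simply the dev version that worked for me, not an official statement of where the fix landed.

```r
# Hypothetical helper: TRUE if a version string is at least the dev version
# reported as working in this thread (0.11.7-9600)
has_weighted_penalty_fix <- function(version, fixed_in = "0.11.7-9600") {
  package_version(version) >= package_version(fixed_in)
}

cran_ok <- has_weighted_penalty_fix("0.11.7")       # CRAN release mentioned above
dev_ok  <- has_weighted_penalty_fix("0.11.7-9600")  # dev build mentioned above

# With PLNmodels installed, the same check on the local install would be:
# has_weighted_penalty_fix(as.character(utils::packageVersion("PLNmodels")))
```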

jchiquet commented 1 year ago

I merged this into master today, so the following code now gives what you expect:

require(PLNmodels); set.seed(42); data(trichoptera)
trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)

models_homegenous_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera)

# Matrix of random penalties
p <- ncol(trichoptera$Abundance); W <- diag(1, p, p)
W[upper.tri(W)] <- runif(p*(p-1)/2, min = 1, max = 5)
W[lower.tri(W)] <- t(W)[lower.tri(W)]

models_weighted_penalties <- PLNnetwork(Abundance ~ 1, data = trichoptera,
                                        control_init = list(penalty_weights = W))


mhp <- getBestModel(models_homegenous_penalties)
mwp <- getBestModel(models_weighted_penalties)

# the sigma matrices are slightly different
mean(abs(sigma(mhp)  - sigma(mwp)))
plot(sigma(mhp), sigma(mwp))

dcalderon commented 1 year ago

Great, thank you!