PLN-team / PLNmodels

A collection of Poisson lognormal models for multivariate count data analysis
https://pln-team.github.io/PLNmodels
GNU General Public License v3.0

torch backend crashing where nlopt succeeds? #107

Closed ctrapnell closed 12 months ago

ctrapnell commented 1 year ago

Hi PLN-team,

We are using this package quite a bit and love it. We are interested in the torch backend and potentially accelerating on GPUs. However, when we run some of the torch tests in torch_PLN.R, we run into errors. For example:

system.time(myPLN_nlopt <- PLN(Abundance ~ 1 + offset(log(Offset)), data = oaks, control = PLN_param(backend = "nlopt", covariance = "diagonal")) )

works:

 Initialization...
 Adjusting a diagonal covariance PLN model with nlopt optimizer
 Post-treatments...
 DONE!
   user  system elapsed
  1.192   1.686   1.490

But:

> system.time(myPLN_torch <- PLN(Abundance ~ 1 + offset(log(Offset)), data = oaks, control = PLN_param(backend = "torch", covariance = "diagonal")) )

Doesn't:

 Initialization...
 Adjusting a diagonal covariance PLN model with torch optimizer
Error in if (delta_f < config$ftol_rel) status <- 3 :
  missing value where TRUE/FALSE needed
Timing stopped at: 0.014 0.002 0.015

Any ideas what might be going on here?

Thank you!

jchiquet commented 1 year ago

Hi @ctrapnell, thanks a lot for your feedback. We are happy and honored to know that you are using PLNmodels in your lab!

The torch backend has not been extensively checked; it is still under development. I'll take a look as soon as possible, especially if there's a need (and it seems that Mahendra is looking at it too).

By the way, if you're looking for a Python implementation of PLNmodels that uses GPUs via PyTorch, we have a Python version in development, covering the standard PLN model and the PCA version. More to come soon. https://pypi.org/project/pyPLNmodels/

mahendra-mariadassou commented 1 year ago

Hi,

This is usually a sign of a convergence problem. In this case, the objective explodes to infinity, so the quantity delta_f becomes NaN at some point, which triggers the error message. You can solve it by using a smaller learning rate in PLN_param(), such as 0.01 instead of 0.1 (this is not well documented in the help; we need to fix that).

library(PLNmodels)
#> This is packages 'PLNmodels' version 1.0.4-0100
#> Use future::plan(multicore/multisession) to speed up PLNPCA/PLNmixture/stability_selection.
data("oaks")
system.time(myPLN_nlopt <-
              PLN(Abundance ~ 1  + offset(log(Offset)),
                  data = oaks, control = PLN_param(backend = "nlopt"))
)
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with nlopt optimizer
#>  Post-treatments...
#>  DONE!
#>        user      system     elapsed
#>      28.274      15.376      12.128
system.time(myPLN_torch <-
              PLN(Abundance ~ 1  + offset(log(Offset)),
                  data = oaks, control = PLN_param(backend = "torch", config_optim = list(lr = 0.01)))
)
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with torch optimizer
#>  Post-treatments...
#>  DONE!
#>        user      system     elapsed
#>      29.274      13.488      11.681

Created on 2023-09-20 with reprex v2.0.2
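
For readers wondering why a diverging objective surfaces as "missing value where TRUE/FALSE needed", here is a minimal base-R sketch of the mechanism. It is not the PLNmodels optimizer: run_gd, its guard argument, and the toy objective are hypothetical, and only the names delta_f and ftol_rel are borrowed from the error message above. The assumption is that the stopping rule compares a relative change in the objective to a tolerance, which yields NaN (Inf / Inf) once the objective overflows.

```r
# Toy gradient descent on f(x) = x^4 + x^2 + 1 (illustration only, not PLNmodels code).
run_gd <- function(lr, guard = FALSE, x0 = 3, ftol_rel = 1e-8, maxit = 2000) {
  f    <- function(x) x^4 + x^2 + 1
  grad <- function(x) 4 * x^3 + 2 * x
  x <- x0
  f_old <- f(x)
  for (i in seq_len(maxit)) {
    x <- x - lr * grad(x)
    f_new <- f(x)
    # Relative change in the objective; Inf / Inf gives NaN once f overflows.
    delta_f <- abs(f_new - f_old) / abs(f_new)
    # Optional guard: report divergence instead of reaching the comparison below.
    if (guard && !is.finite(delta_f)) {
      return(list(status = "diverged", iterations = i))
    }
    # Unguarded comparison: `if` raises an error when delta_f is NaN.
    if (delta_f < ftol_rel) {
      return(list(status = "converged", iterations = i))
    }
    f_old <- f_new
  }
  list(status = "maxit", iterations = maxit)
}

# Large step: the objective overflows and the comparison fails with an error,
# whose message is captured here as a character string.
unguarded <- tryCatch(run_gd(lr = 0.1), error = function(e) conditionMessage(e))

# Same step size with the guard: divergence is reported instead of an error.
guarded <- run_gd(lr = 0.1, guard = TRUE)

# Smaller step size: the iteration settles and the tolerance test is met.
small_lr <- run_gd(lr = 0.01)
```

In this toy setting the guard only turns the crash into a readable status. The actual remedy is still the smaller learning rate, as in the reprex above.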