CDCgov / cfa-viral-lineage-model


Unstandardized time #66

Closed · afmagee42 closed 1 day ago

afmagee42 commented 3 weeks ago

The goal of this PR was to have models that work on unstandardized time with appropriate priors. This entailed two interrelated matters.

Models and MCMC

I un-standardized time, resolving #47. I also added an abstract class from which all models now descend. Models now also carry information about how their MCMC should be run and diagnosed.
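As a rough illustration of that structure, here is a minimal sketch of such an abstract base class. Only `.dense_mass()` is a name taken from this PR; the class name and the other attribute and method names are purely illustrative.

```python
# Minimal sketch of the abstract-base-class idea; only dense_mass() is named in this PR,
# the class name and other attribute/method names are illustrative placeholders.
from abc import ABC, abstractmethod


class BaseModel(ABC):
    # hypothetical per-model MCMC settings a concrete model could override
    num_warmup: int = 1_000
    num_samples: int = 1_000

    @abstractmethod
    def numpyro_model(self, *args, **kwargs) -> None:
        """The NumPyro model to be sampled."""

    @abstractmethod
    def dense_mass(self) -> list[tuple[str, ...]] | bool:
        """Value to pass as the dense_mass argument of NumPyro's NUTS kernel."""
```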

MCMC

I altered our MCMC settings and convergence diagnostic thresholds. We now

Mass matrices

I rewrote the IndependentDivisionsModel and the HierarchicalDivisionsModel to use a loop over states. This significantly slows model compilation, but in return it allows our models to mix sufficiently (in testing, ESS per sample is drastically higher and its variance much lower).

The mixing gain comes from the fact that this loop allows us to use a block-diagonal mass matrix, which adapts to the correlations between slopes and intercepts within each state. A full mass matrix worked well most of the time but occasionally crashed entire chains (this sank #58). The simpler block-diagonal structure appears to be learnable and to capture sufficient correlation even in the hierarchical model.

Via .dense_mass(), each model can now return a value appropriate for passing to the dense_mass argument of NumPyro's NUTS implementation.
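As a self-contained sketch (the toy model and site names below are illustrative, not this repository's), this is how a per-state loop plus a list of site-name tuples passed as `dense_mass` gives NUTS one dense mass-matrix block per state:

```python
# Illustrative toy model, not the repository's: a Python loop over states gives each
# state its own named sites, and dense_mass can then request one dense block per state.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

STATES = ["AK", "AL", "AZ"]


def toy_model(t=jnp.arange(5.0), y=jnp.zeros(5)):
    for state in STATES:
        intercept = numpyro.sample(f"intercept_{state}", dist.Normal(0.0, 1.0))
        slope = numpyro.sample(f"slope_{state}", dist.Normal(0.0, 1.0))
        numpyro.sample(f"y_{state}", dist.Normal(intercept + slope * t, 1.0), obs=y)


def dense_mass():
    # One tuple per state -> one dense block per state in the mass matrix,
    # capturing the within-state slope/intercept correlation; all other sites
    # get a diagonal block.
    return [(f"intercept_{state}", f"slope_{state}") for state in STATES]


kernel = NUTS(toy_model, dense_mass=dense_mass())
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))
```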

Other model changes

I simplified the HierarchicalDivisionsModel by removing the multivariate distribution on how slopes deviate from the overall mean term (slopes, like intercepts, now deviate independently across variants). This was killing MCMC performance, and preliminary testing suggests that the models generally prefer very similar slopes across regions anyway. A loop-based version of the previous model is available as the CorrelatedDeviationsModel, but it is not enabled in the configs in retrospective-forecasting.
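To make the structural change concrete, here is a hedged sketch with illustrative names, shapes, and hyperparameters: the removed prior placed a multivariate normal over slope deviations across variants, whereas the simplified model uses independent normals.

```python
# Illustrative sketch only; names, shapes, and hyperparameters are not the repository's.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

n_variants = 4


def slope_deviations_correlated(cov):
    # previous structure (retained as CorrelatedDeviationsModel): deviations from the
    # overall mean slope are correlated across variants
    return numpyro.sample(
        "slope_deviations",
        dist.MultivariateNormal(jnp.zeros(n_variants), covariance_matrix=cov),
    )


def slope_deviations_independent(sd):
    # simplified HierarchicalDivisionsModel: each variant's deviation is independent
    return numpyro.sample("slope_deviations", dist.Normal(0.0, sd).expand([n_variants]))
```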

I also made it easier to change the "strength" of pooling as a user setting, separately for slopes and intercepts. This term lies in [0, 1] and is the fraction of variance allocated to the global mean term, with one minus this amount going to local deviations around that mean. Thus, 0 is complete independence and 1 is complete pooling. I left the intercept at the previous implicit value of 0.5, but increased the slope default to 0.75, as preliminary testing suggested larger values are appropriate.
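A small worked sketch of that variance split (the variable names and the total variance of 1.0 are illustrative, not the configured values):

```python
import numpy as np

total_variance = 1.0      # illustrative total prior variance for slopes
pooling_strength = 0.75   # new default for slopes; intercepts stay at 0.5

# fraction `pooling_strength` of the variance goes to the shared global mean term...
global_sd = np.sqrt(pooling_strength * total_variance)
# ...and the remaining fraction goes to local deviations around that mean
local_sd = np.sqrt((1.0 - pooling_strength) * total_variance)

# pooling_strength = 0 -> all variance in local deviations (complete independence)
# pooling_strength = 1 -> all variance in the shared mean (complete pooling)
```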

Pipeline

The increased number of model iterations provoked a memory leak, or something that looks a lot like one, which caused both evaluation and plotting to fail in retrospective-forecasting. It was apparently triggered by the plotting but interacted with the upstream evaluation's polars dataframes in ways I cannot comprehend.

To keep our ability to plot things, I therefore had to refactor the pipeline. The changes to main look bigger than they are; the short version is that, in order to push all the plotting into separate system python calls, I
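For context, a minimal sketch of the isolation idea (the script name and CLI arguments are hypothetical, not the pipeline's actual interface): each plotting step runs in its own python process, so any memory it leaks is reclaimed when that process exits.

```python
# Hypothetical sketch of plotting via a separate system python call; the script name
# and arguments are placeholders, not the pipeline's actual interface.
import subprocess
import sys


def plot_in_subprocess(script: str, *args: str) -> None:
    # sys.executable reuses the current interpreter, so the child sees the same environment;
    # check=True surfaces plotting failures as exceptions in the calling pipeline
    subprocess.run([sys.executable, script, *args], check=True)


plot_in_subprocess("plot_forecasts.py", "--input", "forecasts.parquet")
```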

Recreational out-of-scope type hinting

I cleaned up a handful of places where type hints were either wrong or making the type checker unhappy.