mkomod / survival.svb

Variational Bayes for High-Dimensional Survival Analysis https://arxiv.org/abs/2112.10270

Question about distribution of `beta_hat` values #2

Open bblodfon opened 1 year ago

bblodfon commented 1 year ago

Hi @mkomod! Great work with this package!

I was trying to understand exactly what posterior distribution the beta_hat values follow. I think the equation below Eq. (6) in the paper relates to this. I see that beta_hat is the mean under the variational approximation (γ*μ) and that there is an output s (sigma => standard deviation). Maybe they follow a Gaussian with mean beta_hat and sd = γ*s? Then what about the δ_0 Dirac term I see in the formula? The idea is to be able to build credible intervals for the beta_hat's or draw from the posterior distribution.

John

mkomod commented 1 year ago

Hi John,

Thanks for your question. Under the variational approximation, the posterior distribution of $\beta_j$ is as follows

$$\beta_j \overset{\text{ind}}{\sim} \gamma_j N(\mu_j, \sigma_j^2) + (1-\gamma_j) \delta_0$$

for $j=1,\dots,p$. That is, $\beta_j$ is a mixture of a Normal distribution with mean $\mu_j$ and variance $\sigma_j^2$ and a Dirac mass at 0. You are correct that you can transform the Normal distribution to have mean $\gamma_j \mu_j$ and variance $\gamma_j^2 \sigma_j^2$; however, this would change the interpretation of the distribution. The variable $\gamma_j$ gives the posterior probability of inclusion, i.e. the probability that the coefficient is non-zero, which is useful for variable selection (a value close to one implies the coefficient is non-zero).
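For reference, here are the standard moments of this spike-and-slab mixture. This is a short derivation added here, not taken from the paper. The Dirac component contributes nothing to either moment, so:

$$
\mathbb{E}[\beta_j] = \gamma_j \mu_j, \qquad
\mathbb{E}[\beta_j^2] = \gamma_j\left(\sigma_j^2 + \mu_j^2\right),
$$

$$
\operatorname{Var}(\beta_j) = \gamma_j\left(\sigma_j^2 + \mu_j^2\right) - \gamma_j^2 \mu_j^2 = \gamma_j \sigma_j^2 + \gamma_j (1-\gamma_j) \mu_j^2 .
$$

Whenever $0 < \gamma_j < 1$ and $\sigma_j > 0$, this variance is strictly larger than $\gamma_j^2 \sigma_j^2$. A single Gaussian with sd $\gamma_j \sigma_j$ therefore understates the spread, and it also loses the point mass $P(\beta_j = 0) = 1 - \gamma_j$.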

Regarding $\widehat{\beta}$: in the paper, $\widehat{\beta}$ is the posterior mean, given by $\widehat{\beta}_j = \gamma_j \mu_j$. You can use it to compute point estimates if needed, e.g. if you'd like to estimate the survival score $\widehat{\eta} = X \widehat{\beta}$.
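Here is a minimal sketch of that point estimate. The values of `g` and `mu` below are made-up stand-ins for `fit$g` and `fit$mu`, and `X` is a hypothetical two-row design matrix:

```r
# Hypothetical stand-ins for fit$g (inclusion probabilities) and fit$mu
fit <- list(g = c(0.95, 0.10, 0.60), mu = c(1.2, -0.4, 0.8))

# Posterior mean of each coefficient: gamma_j * mu_j
beta_hat <- fit$g * fit$mu

# Hypothetical design matrix (one row per subject)
X <- matrix(c(1, 0, 2,
              0, 1, 1), nrow = 2, byrow = TRUE)

# Point estimate of the survival (risk) score eta_hat = X beta_hat
eta_hat <- drop(X %*% beta_hat)
```

Because $\widehat{\eta}$ is linear in $\widehat{\beta}$, this is also the posterior mean of $X\beta$. It does not, however, say anything about the uncertainty of $X\beta$.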

If, however, you'd like to quantify the uncertainty about some quantity of interest, e.g. $X \beta$, then the following code, which draws samples of $\beta$ from the approximate posterior, might be of use:

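svb.sample <- function(fit, samples = 1e4)
{
    # fit$g:  posterior inclusion probabilities gamma_j
    # fit$mu: means of the Normal (slab) components
    # fit$s:  standard deviations of the Normal (slab) components
    p <- length(fit$g)

    beta <- replicate(samples,
    {
        # include coefficient j with probability gamma_j
        j <- runif(p) <= fit$g
        b <- rep(0, p)

        if (sum(j) == 0)
            return(b)

        # draw the included coefficients from N(mu_j, s_j^2)
        b[j] <- rnorm(sum(j), fit$mu[j], fit$s[j])
        return(b)
    })

    # p x samples matrix (also when p = 1): each column is one draw of beta
    return(matrix(beta, nrow = p))
}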
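Here is a usage sketch showing how such draws become 95% credible intervals. It uses a vectorised variant of the sampler above. That variant is a design alternative written for this example, not the package's code. The values of `g`, `mu`, `s` and `X` are made up for illustration.

```r
set.seed(1)

# Hypothetical stand-ins for fit$g, fit$mu, fit$s
fit <- list(g  = c(0.95, 0.10, 0.60),
            mu = c(1.2, -0.4, 0.8),
            s  = c(0.2, 0.3, 0.25))

# Vectorised draws from the spike-and-slab approximation.
# g, mu and s are recycled down each column, so row j uses parameter j.
svb_sample_vec <- function(fit, samples = 1e4) {
  p <- length(fit$g)
  incl <- matrix(runif(p * samples) <= fit$g, nrow = p)        # inclusion indicators
  slab <- matrix(rnorm(p * samples, fit$mu, fit$s), nrow = p)  # Normal draws
  slab * incl                                                  # zero where excluded
}

beta <- svb_sample_vec(fit, 1e4)   # 3 x 10000, one draw per column

# 95% equal-tailed credible interval for each coefficient (2 x p)
ci_beta <- apply(beta, 1, quantile, probs = c(0.025, 0.975))

# Draws of the linear predictor X beta for a hypothetical 5 x 3 design matrix
X <- matrix(rnorm(5 * 3), nrow = 5)
eta <- X %*% beta                  # 5 x 10000
ci_eta <- apply(eta, 1, quantile, probs = c(0.025, 0.975))
```

Because of the point mass at 0, an interval for a coefficient with small $\gamma_j$ can have 0 as an endpoint. In that case $\gamma_j$ itself is usually the more informative summary.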

Please let me know if you need more clarification.

Michael

P.S. From a recent project, I'd recommend orthogonalizing the design matrix in some way, as this improves the stability of the algorithm. I can dig up some code for this if needed.

bblodfon commented 1 year ago

Hi @mkomod,

Thanks for answering this!

I am coming from the ML side of things, so I would encourage you to also add a predict(fit, newdata, ...) method, where users can supply new X data (a test set) and get the predicted risk scores $\eta$. It would also help to have the option to get the full matrix of draws (as in the example code above), so that uncertainty can be quantified (which is the advantage of such Bayesian methods). It would be even better if survival predictions were provided for the posterior mean betas (via the Breslow estimator).

This will make it easier to wrap your method in the mlr3 framework and be used by others. I am currently trying to include more Bayesian survival models, like survival BART for example. Maybe with a little bit of help from you, we can include survival.svb as well!
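Here is a rough sketch of what such a predict method could compute. `svb_predict_sketch` and `breslow_survival_sketch` are hypothetical names, not functions in survival.svb. `fit` is assumed to expose `g`, `mu` and `s` as in the sampling code above. The Breslow step plugs in posterior-mean risk scores, as suggested, rather than integrating over the posterior.

```r
# HYPOTHETICAL helpers, not part of survival.svb.
# `fit` is assumed to be a list with elements g, mu, s.

# Risk scores for new data: posterior-mean eta, plus optional posterior draws.
svb_predict_sketch <- function(fit, newdata, samples = 0) {
  beta_hat <- fit$g * fit$mu                       # posterior mean of beta
  out <- list(eta = drop(newdata %*% beta_hat))
  if (samples > 0) {
    p <- length(fit$g)
    incl <- matrix(runif(p * samples) <= fit$g, nrow = p)
    slab <- matrix(rnorm(p * samples, fit$mu, fit$s), nrow = p)
    out$eta_samples <- newdata %*% (slab * incl)   # nrow(newdata) x samples
  }
  out
}

# Breslow estimate of the baseline cumulative hazard H0, evaluated at `times`,
# combined with risk scores as S(t | x) = exp(-H0(t) * exp(eta(x))).
breslow_survival_sketch <- function(time, status, eta_train, eta_new, times) {
  event_times <- sort(unique(time[status == 1]))
  # increment at each distinct event time: (# events) / sum of exp(eta) in risk set
  dH <- vapply(event_times, function(tk) {
    sum(time == tk & status == 1) / sum(exp(eta_train[time >= tk]))
  }, numeric(1))
  H0 <- vapply(times, function(t) sum(dH[event_times <= t]), numeric(1))
  # rows: new observations; columns: evaluation times
  exp(-outer(exp(eta_new), H0))
}
```

For example, `eta_train` would be `svb_predict_sketch(fit, X_train)$eta` and `eta_new` would be `svb_predict_sketch(fit, X_test)$eta`. Tied event times are handled with the usual Breslow convention of summing events at each time.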

bblodfon commented 2 months ago

@mkomod Could we add a predict method? :)